diff options
Diffstat (limited to 'source')
45 files changed, 15713 insertions, 14991 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.cpp b/source/compiler-core/slang-diagnostic-sink.cpp index 28a98266a..de8bdf52c 100644 --- a/source/compiler-core/slang-diagnostic-sink.cpp +++ b/source/compiler-core/slang-diagnostic-sink.cpp @@ -828,4 +828,45 @@ DiagnosticsLookup::DiagnosticsLookup( add(diagnostics, diagnosticsCount); } +void outputExceptionDiagnostic( + const AbortCompilationException& exception, + DiagnosticSink& sink, + slang::IBlob** outDiagnostics) +{ + sink.diagnoseRaw(Severity::Error, exception.Message.getUnownedSlice()); + sink.getBlobIfNeeded(outDiagnostics); +} + +void outputExceptionDiagnostic( + const Exception& exception, + DiagnosticSink& sink, + slang::IBlob** outDiagnostics) +{ + try + { + sink.diagnoseRaw(Severity::Internal, exception.Message.getUnownedSlice()); + } + catch (const AbortCompilationException&) + { + // Catch and ignore the AbortCompilationException that diagnoseRaw throws + // for Internal severity to prevent exception leak from loadModule + } + sink.getBlobIfNeeded(outDiagnostics); +} + +void outputExceptionDiagnostic(DiagnosticSink& sink, slang::IBlob** outDiagnostics) +{ + try + { + sink.diagnoseRaw(Severity::Fatal, "An unknown exception occurred"); + } + catch (const AbortCompilationException&) + { + // Catch and ignore the AbortCompilationException that diagnoseRaw throws + // for Fatal severity to prevent exception leak from loadModule + } + sink.getBlobIfNeeded(outDiagnostics); +} + + } // namespace Slang diff --git a/source/compiler-core/slang-diagnostic-sink.h b/source/compiler-core/slang-diagnostic-sink.h index 2d60747d9..fe05d953b 100644 --- a/source/compiler-core/slang-diagnostic-sink.h +++ b/source/compiler-core/slang-diagnostic-sink.h @@ -393,6 +393,19 @@ protected: MemoryArena m_arena; }; + +void outputExceptionDiagnostic( + const AbortCompilationException& exception, + DiagnosticSink& sink, + slang::IBlob** outDiagnostics); + +void outputExceptionDiagnostic( + const Exception& exception, + DiagnosticSink& sink, + slang::IBlob** outDiagnostics); + +void outputExceptionDiagnostic(DiagnosticSink& sink, slang::IBlob** outDiagnostics); + } // namespace Slang #endif diff --git a/source/core/slang-type-convert-util.h b/source/core/slang-type-convert-util.h index d4c80abf8..45cb4b82c 100644 --- a/source/core/slang-type-convert-util.h +++ b/source/core/slang-type-convert-util.h @@ -1,6 +1,14 @@ #ifndef SLANG_CORE_TYPE_CONVERT_UTIL_H #define SLANG_CORE_TYPE_CONVERT_UTIL_H +// TODO: This file's name is not obvious for what it contains. +// Either the file should be renamed to be more obviously related +// to what it does, or (better yet) the functionality should be +// moved to reside in places that are more logically related +// to each of the given types. +// +// Also: this doesn't belong in `core` for a bunch of reasons. + #include "slang.h" namespace Slang diff --git a/source/core/slang-type-text-util.h b/source/core/slang-type-text-util.h index 684d109c3..075b9e2a0 100644 --- a/source/core/slang-type-text-util.h +++ b/source/core/slang-type-text-util.h @@ -1,6 +1,14 @@ #ifndef SLANG_CORE_TYPE_TEXT_UTIL_H #define SLANG_CORE_TYPE_TEXT_UTIL_H +// TODO: This file's name is not obvious for what it contains. +// Either the file should be renamed to be more obviously related +// to what it does, or (better yet) the functionality should be +// moved to reside in places that are more logically related +// to each of the given types. +// +// Also: this doesn't belong in `core` for a bunch of reasons. + #include "slang-array-view.h" #include "slang-name-value.h" #include "slang-string.h" diff --git a/source/slang/slang-base-type-info.cpp b/source/slang/slang-base-type-info.cpp new file mode 100644 index 000000000..9072e34e4 --- /dev/null +++ b/source/slang/slang-base-type-info.cpp @@ -0,0 +1,95 @@ +// slang-base-type-info.cpp +#include "slang-base-type-info.h" + +namespace Slang +{ + +/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOf)] = { + {0, 0, uint8_t(BaseType::Void)}, + {uint8_t(sizeof(bool)), 0, uint8_t(BaseType::Bool)}, + {uint8_t(sizeof(int8_t)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::Int8)}, + {uint8_t(sizeof(int16_t)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::Int16)}, + {uint8_t(sizeof(int32_t)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::Int)}, + {uint8_t(sizeof(int64_t)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::Int64)}, + {uint8_t(sizeof(uint8_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt8)}, + {uint8_t(sizeof(uint16_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt16)}, + {uint8_t(sizeof(uint32_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt)}, + {uint8_t(sizeof(uint64_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt64)}, + {uint8_t(sizeof(uint16_t)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Half)}, + {uint8_t(sizeof(float)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Float)}, + {uint8_t(sizeof(double)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Double)}, + {uint8_t(sizeof(char)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::Char)}, + {uint8_t(sizeof(intptr_t)), + BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, + uint8_t(BaseType::IntPtr)}, + {uint8_t(sizeof(uintptr_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UIntPtr)}, +}; + +/* static */ bool BaseTypeInfo::check() +{ + for (Index i = 0; i < SLANG_COUNT_OF(s_info); ++i) + { + if (s_info[i].baseType != i) + { + SLANG_ASSERT(!"Inconsistency between the s_info table and BaseInfo"); + return false; + } + } + return true; +} + +/* static */ UnownedStringSlice BaseTypeInfo::asText(BaseType baseType) +{ + switch (baseType) + { + case BaseType::Void: + return UnownedStringSlice::fromLiteral("void"); + case BaseType::Bool: + return UnownedStringSlice::fromLiteral("bool"); + case BaseType::Int8: + return UnownedStringSlice::fromLiteral("int8_t"); + case BaseType::Int16: + return UnownedStringSlice::fromLiteral("int16_t"); + case BaseType::Int: + return UnownedStringSlice::fromLiteral("int"); + case BaseType::Int64: + return UnownedStringSlice::fromLiteral("int64_t"); + case BaseType::UInt8: + return UnownedStringSlice::fromLiteral("uint8_t"); + case BaseType::UInt16: + return UnownedStringSlice::fromLiteral("uint16_t"); + case BaseType::UInt: + return UnownedStringSlice::fromLiteral("uint"); + case BaseType::UInt64: + return UnownedStringSlice::fromLiteral("uint64_t"); + case BaseType::Half: + return UnownedStringSlice::fromLiteral("half"); + case BaseType::Float: + return UnownedStringSlice::fromLiteral("float"); + case BaseType::Double: + return UnownedStringSlice::fromLiteral("double"); + case BaseType::Char: + return UnownedStringSlice::fromLiteral("char"); + case BaseType::IntPtr: + return UnownedStringSlice::fromLiteral("intptr_t"); + case BaseType::UIntPtr: + return UnownedStringSlice::fromLiteral("uintptr_t"); + default: + { + SLANG_ASSERT(!"Unknown basic type"); + return UnownedStringSlice(); + } + } +} + +} // namespace Slang diff --git a/source/slang/slang-base-type-info.h b/source/slang/slang-base-type-info.h new file mode 100644 index 000000000..4b96af18f --- /dev/null +++ b/source/slang/slang-base-type-info.h @@ -0,0 +1,50 @@ +// slang-base-type-info.h +#pragma once + +// +// This file defines the `BaseTypeInfo` type, which encodes +// information (such as size in bits) about the base types +// supported by the Slang language. That information is used +// for things like checking if a literal is in the representible +// range of a given type, and for determining the relative +// cost of implicit conversions between the base types. +// + +#include "../core/slang-basic.h" +#include "slang-type-system-shared.h" + +namespace Slang +{ + +// Information about BaseType that's useful for checking literals +struct BaseTypeInfo +{ + typedef uint8_t Flags; + struct Flag + { + enum Enum : Flags + { + Signed = 0x1, + FloatingPoint = 0x2, + Integer = 0x4, + }; + }; + + SLANG_FORCE_INLINE static const BaseTypeInfo& getInfo(BaseType baseType) + { + return s_info[Index(baseType)]; + } + + static UnownedStringSlice asText(BaseType baseType); + + uint8_t sizeInBytes; ///< Size of type in bytes + Flags flags; + uint8_t baseType; + + static bool check(); + +private: + static const BaseTypeInfo s_info[Index(BaseType::CountOf)]; +}; + +} // namespace Slang diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 3145c9454..a82278054 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3261,4 +3261,19 @@ bool getExtensionTargetDeclList( ExtensionDecl* extDeclRef, ShortList<AggTypeDecl*>& targetDecls); +void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink); + +RefPtr<ComponentType> createUnspecializedGlobalComponentType( + FrontEndCompileRequest* compileRequest); + +RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( + FrontEndCompileRequest* compileRequest, + List<RefPtr<ComponentType>>& outUnspecializedEntryPoints); + +RefPtr<ComponentType> createSpecializedGlobalComponentType(EndToEndCompileRequest* endToEndReq); + +RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( + EndToEndCompileRequest* endToEndReq, + List<RefPtr<ComponentType>>& outSpecializedEntryPoints); + } // namespace Slang diff --git a/source/slang/slang-code-gen.cpp b/source/slang/slang-code-gen.cpp new file mode 100644 index 000000000..cd47147e2 --- /dev/null +++ b/source/slang/slang-code-gen.cpp @@ -0,0 +1,1405 @@ +// slang-code-gen.cpp +#include "slang-code-gen.h" + +#include "../compiler-core/slang-slice-allocator.h" +#include "../core/slang-type-convert-util.h" +#include "../core/slang-type-text-util.h" +#include "slang-compiler.h" +#include "slang-emit-cuda.h" // for `CUDAExtensionTracker` +#include "slang-extension-tracker.h" // for `ShaderExtensionTracker` + +// TODO: The "artifact" system is a scourge. +#include "../compiler-core/slang-artifact-desc-util.h" +#include "../compiler-core/slang-artifact-impl.h" +#include "../compiler-core/slang-artifact-util.h" +#include "slang-artifact-output-util.h" + +namespace Slang +{ + +EndToEndCompileRequest* CodeGenContext::isPassThroughEnabled() +{ + auto endToEndReq = isEndToEndCompile(); + + // If there isn't an end-to-end compile going on, + // there can be no pass-through. + // + if (!endToEndReq) + return nullptr; + + // And if pass-through isn't set on that end-to-end compile, + // then we clearly areb't doing a pass-through compile. + // + if (endToEndReq->m_passThrough == PassThroughMode::None) + return nullptr; + + // If we have confirmed that pass-through compilation is going on, + // we return the end-to-end request, because it has all the + // relevant state that we need to implement pass-through mode. + // + return endToEndReq; +} + +/// If there is a pass-through compile going on, find the translation unit for the given entry +/// point. Assumes isPassThroughEnabled has already been called +TranslationUnitRequest* getPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) +{ + SLANG_ASSERT(endToEndReq); + SLANG_ASSERT(endToEndReq->m_passThrough != PassThroughMode::None); + auto frontEndReq = endToEndReq->getFrontEndReq(); + auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); + auto translationUnit = entryPointReq->getTranslationUnit(); + return translationUnit; +} + +TranslationUnitRequest* CodeGenContext::findPassThroughTranslationUnit(Int entryPointIndex) +{ + if (auto endToEndReq = isPassThroughEnabled()) + return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + return nullptr; +} + +static void _appendCodeWithPath( + const UnownedStringSlice& filePath, + const UnownedStringSlice& fileContent, + StringBuilder& outCodeBuilder) +{ + outCodeBuilder << "#line 1 \""; + auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp); + handler->appendEscaped(filePath, outCodeBuilder); + outCodeBuilder << "\"\n"; + outCodeBuilder << fileContent << "\n"; +} + +#if SLANG_VC +// TODO(JS): This is a workaround +// In debug VS builds there is a warning on line about it being unreachable. +// for (auto entryPointIndex : getEntryPointIndices()) +// It's not clear how that could possibly be unreachable +// +// Note(tfoley): The diagnostic noted above arises because the `for` +// loop in question unconditionally exits on its first iteration. +// As a result the automatically-generated code for the "continue clause" +// (more or less the `operator++` on the iterator) is identified +// as unreachable code. +// +// The actual fix would be to make this code more explicit about its +// expectations. Either it expects there to be only a single entry +// point (in which case it should diagnose if that expectation is +// not met), or it just wants to query for the *first* entry point +// more explicitly, and include a comment explaining why that is valid. +// +#pragma warning(push) +#pragma warning(disable : 4702) +#endif +SlangResult CodeGenContext::emitEntryPointsSource(ComPtr<IArtifact>& outArtifact) +{ + outArtifact.setNull(); + + SLANG_RETURN_ON_FAIL(requireTranslationUnitSourceFiles()); + + auto endToEndReq = isPassThroughEnabled(); + if (endToEndReq) + { + for (auto entryPointIndex : getEntryPointIndices()) + { + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + SLANG_ASSERT(translationUnit); + + /// Make sure we have the source files + SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); + + // Generate a string that includes the content of + // the source file(s), along with a line directive + // to ensure that we get reasonable messages + // from the downstream compiler when in pass-through + // mode. + + StringBuilder codeBuilder; + if (getTargetFormat() == CodeGenTarget::GLSL) + { + // Special case GLSL + int translationUnitCounter = 0; + for (auto sourceFile : translationUnit->getSourceFiles()) + { + int translationUnitIndex = translationUnitCounter++; + + // We want to output `#line` directives, but we need + // to skip this for the first file, since otherwise + // some GLSL implementations will get tripped up by + // not having the `#version` directive be the first + // thing in the file. + if (translationUnitIndex != 0) + { + codeBuilder << "#line 1 " << translationUnitIndex << "\n"; + } + codeBuilder << sourceFile->getContent() << "\n"; + } + } + else + { + for (auto sourceFile : translationUnit->getSourceFiles()) + { + _appendCodeWithPath( + sourceFile->getPathInfo().foundPath.getUnownedSlice(), + sourceFile->getContent(), + codeBuilder); + } + } + + auto artifact = + ArtifactUtil::createArtifactForCompileTarget(asExternal(getTargetFormat())); + artifact->addRepresentationUnknown(StringBlob::moveCreate(codeBuilder)); + + outArtifact.swap(artifact); + return SLANG_OK; + } + return SLANG_OK; + } + else + { + return emitEntryPointsSourceFromIR(outArtifact); + } +} +#if SLANG_VC +#pragma warning(pop) +#endif + +SlangResult CodeGenContext::emitPrecompiledDownstreamIR(ComPtr<IArtifact>& outArtifact) +{ + return _emitEntryPoints(outArtifact); +} + +static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) +{ + if (sink->isFlagSet(DiagnosticSink::Flag::VerbosePath)) + { + return sourceFile->calcVerbosePath(); + } + else + { + return sourceFile->getPathInfo().foundPath; + } +} + +String CodeGenContext::calcSourcePathForEntryPoints() +{ + String failureMode = "slang-generated"; + if (getEntryPointCount() != 1) + return failureMode; + auto entryPointIndex = getSingleEntryPointIndex(); + auto translationUnitRequest = findPassThroughTranslationUnit(entryPointIndex); + if (!translationUnitRequest) + return failureMode; + + const auto& sourceFiles = translationUnitRequest->getSourceFiles(); + + auto sink = getSink(); + + const Index numSourceFiles = sourceFiles.getCount(); + + switch (numSourceFiles) + { + case 0: + return "unknown"; + case 1: + return _getDisplayPath(sink, sourceFiles[0]); + default: + { + StringBuilder builder; + builder << _getDisplayPath(sink, sourceFiles[0]); + for (int i = 1; i < numSourceFiles; ++i) + { + builder << ";" << _getDisplayPath(sink, sourceFiles[i]); + } + return builder; + } + } +} + +static RefPtr<ExtensionTracker> _newExtensionTracker(CodeGenTarget target) +{ + switch (target) + { + case CodeGenTarget::PTX: + case CodeGenTarget::CUDASource: + { + return new CUDAExtensionTracker; + } + case CodeGenTarget::SPIRV: + case CodeGenTarget::GLSL: + case CodeGenTarget::WGSL: + case CodeGenTarget::WGSLSPIRV: + case CodeGenTarget::WGSLSPIRVAssembly: + { + return new ShaderExtensionTracker; + } + default: + return nullptr; + } +} + +static CodeGenTarget _getDefaultSourceForTarget(CodeGenTarget target) +{ + switch (target) + { + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + { + return CodeGenTarget::CPPSource; + } + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostSharedLibrary: + { + return CodeGenTarget::HostCPPSource; + } + case CodeGenTarget::PTX: + return CodeGenTarget::CUDASource; + case CodeGenTarget::DXBytecode: + return CodeGenTarget::HLSL; + case CodeGenTarget::DXIL: + return CodeGenTarget::HLSL; + case CodeGenTarget::SPIRV: + return CodeGenTarget::GLSL; + case CodeGenTarget::MetalLib: + return CodeGenTarget::Metal; + case CodeGenTarget::WGSLSPIRV: + return CodeGenTarget::WGSL; + default: + break; + } + return CodeGenTarget::Unknown; +} + +void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps) +{ + for (auto& conjunctions : caps.getAtomSets()) + { + for (auto atom : conjunctions) + { + switch (asAtom(atom)) + { + default: + break; + + case CapabilityAtom::glsl_spirv_1_0: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 0)); + break; + case CapabilityAtom::glsl_spirv_1_1: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 1)); + break; + case CapabilityAtom::glsl_spirv_1_2: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 2)); + break; + case CapabilityAtom::glsl_spirv_1_3: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 3)); + break; + case CapabilityAtom::glsl_spirv_1_4: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); + break; + case CapabilityAtom::glsl_spirv_1_5: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 5)); + break; + case CapabilityAtom::glsl_spirv_1_6: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 6)); + break; + } + } + } +} + +// True if it's best to use 'emitted' source for complication. For a downstream compiler +// that is not file based, this is always ok. +/// +/// If the downstream compiler is file system based, we may want to just use the file that was +/// passed to be compiled. That the downstream compiler can determine if it will then save the file +/// or not based on if it's a match - and generally there will not be a match with emitted source. +/// +/// This test is only used for pass through mode. +static bool _useEmittedSource( + IDownstreamCompiler* compiler, + TranslationUnitRequest* translationUnit) +{ + // We only bother if it's a file based compiler. + if (compiler->isFileBased()) + { + // It can only have *one* source file as otherwise we have to combine to make a new source + // file anyway + return translationUnit->getSourceArtifacts().getCount() != 1; + } + return true; +} + +static bool _shouldSetEntryPointName(TargetProgram* targetProgram) +{ + if (!isKhronosTarget(targetProgram->getTargetReq())) + return true; + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::VulkanUseEntryPointName)) + return true; + return false; +} + +static bool _isCPUHostTarget(CodeGenTarget target) +{ + auto desc = ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)); + return desc.style == ArtifactStyle::Host; +} + +SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& outArtifact) +{ + outArtifact.setNull(); + + auto sink = getSink(); + auto session = getSession(); + + CodeGenTarget sourceTarget = CodeGenTarget::None; + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + + auto target = getTargetFormat(); + RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); + PassThroughMode compilerType; + + SliceAllocator allocator; + + if (auto endToEndReq = isPassThroughEnabled()) + { + compilerType = endToEndReq->m_passThrough; + } + else + { + // If we are not in pass through, lookup the default compiler for the emitted source type + + // Get the default source codegen type for a given target + sourceTarget = _getDefaultSourceForTarget(target); + compilerType = (PassThroughMode)session->getDownstreamCompilerForTransition( + (SlangCompileTarget)sourceTarget, + (SlangCompileTarget)target); + // We should have a downstream compiler set at this point + if (compilerType == PassThroughMode::None) + { + auto sourceName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(sourceTarget)); + auto targetName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(target)); + + sink->diagnose( + SourceLoc(), + Diagnostics::compilerNotDefinedForTransition, + sourceName, + targetName); + return SLANG_FAIL; + } + } + + SLANG_ASSERT(compilerType != PassThroughMode::None); + + // Get the required downstream compiler + IDownstreamCompiler* compiler = session->getOrLoadDownstreamCompiler(compilerType, sink); + if (!compiler) + { + auto compilerName = TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)compilerType); + sink->diagnose(SourceLoc(), Diagnostics::passThroughCompilerNotFound, compilerName); + return SLANG_FAIL; + } + + Dictionary<String, String> preprocessorDefinitions; + List<String> includePaths; + + typedef DownstreamCompileOptions CompileOptions; + CompileOptions options; + + List<DownstreamCompileOptions::CapabilityVersion> requiredCapabilityVersions; + List<String> compilerSpecificArguments; + List<ComPtr<IArtifact>> libraries; + List<String> libraryPaths; + + // Set compiler specific args + { + auto name = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); + List<String> downstreamArgs = getTargetProgram()->getOptionSet().getDownstreamArgs(name); + for (const auto& arg : downstreamArgs) + { + // We special case some kinds of args, that can be handled directly + if (arg.startsWith("-I")) + { + // We handle the -I option, by just adding to the include paths + includePaths.add(arg.getUnownedSlice().tail(2)); + } + else + { + compilerSpecificArguments.add(arg); + } + } + } + + ComPtr<IArtifact> sourceArtifact; + + /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we + would ideally like to use the original file. We want to do this because we want includes + relative to the source file to work, and for that to work most easily we want to use the + original file, if there is one */ + if (auto endToEndReq = isPassThroughEnabled()) + { + // If we are pass through, we may need to set extension tracker state. + if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) + { + trackGLSLTargetCaps(glslTracker, getTargetCaps()); + } + + auto translationUnit = + getPassThroughTranslationUnit(endToEndReq, getSingleEntryPointIndex()); + + // We are just passing thru, so it's whatever it originally was + sourceLanguage = translationUnit->sourceLanguage; + + // TODO(JS): This seems like a bit of a hack + // That if a pass-through is being performed and the source language is Slang + // no downstream compiler knows how to deal with that, so probably means 'HLSL' + sourceLanguage = + (sourceLanguage == SourceLanguage::Slang) ? SourceLanguage::HLSL : sourceLanguage; + sourceTarget = CodeGenTarget(TypeConvertUtil::getCompileTargetFromSourceLanguage( + (SlangSourceLanguage)sourceLanguage)); + + // If it's pass through we accumulate the preprocessor definitions. + for (const auto& define : + endToEndReq->getOptionSet().getArray(CompilerOptionName::MacroDefine)) + preprocessorDefinitions.add(define.stringValue, define.stringValue2); + for (const auto& define : translationUnit->preprocessorDefinitions) + preprocessorDefinitions.add(define); + + { + /* TODO(JS): Not totally clear what options should be set here. If we are using the pass + through - then using say the defines/includes all makes total sense. If we are + generating C++ code from slang, then should we really be using these values -> aren't + they what is being set for the *slang* source, not for the C++ generated code. That + being the case it implies that there needs to be a mechanism (if there isn't already) to + specify such information on a particular pass/pass through etc. + + On invoking DXC for example include paths do not appear to be set at all (even with + pass-through). + */ + + auto linkage = getLinkage(); + + // Add all the search paths + + const auto searchDirectories = linkage->getSearchDirectories(); + const SearchDirectoryList* searchList = &searchDirectories; + while (searchList) + { + for (const auto& searchDirectory : searchList->searchDirectories) + { + includePaths.add(searchDirectory.path); + } + searchList = searchList->parent; + } + } + + // If emitted source is required, emit and set the path + if (_useEmittedSource(compiler, translationUnit)) + { + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); + + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); + + // If it's not file based we can set an appropriate path name, and it doesn't matter if + // it doesn't exist on the file system. We set the name to the path as this will be used + // for downstream reporting. + auto sourcePath = calcSourcePathForEntryPoints(); + sourceArtifact->setName(sourcePath.getBuffer()); + + sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); + } + else + { + // Special case if we have a single file, so that we pass the path, and the contents as + // is. + const auto& sourceArtifacts = translationUnit->getSourceArtifacts(); + SLANG_ASSERT(sourceArtifacts.getCount() == 1); + + sourceArtifact = sourceArtifacts[0]; + SLANG_ASSERT(sourceArtifact); + } + } + else + { + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); + + sourceCodeGenContext.removeAvailableInDownstreamIR = true; + + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); + sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); + + sourceLanguage = (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget( + (SlangCompileTarget)sourceTarget); + } + + if (sourceArtifact) + { + // Set the source artifacts + options.sourceArtifacts = makeSlice(sourceArtifact.readRef(), 1); + } + + // Add any preprocessor definitions associated with the linkage + { + // TODO(JS): This is somewhat arguable - should defines passed to Slang really be + // passed to downstream compilers? It does appear consistent with the behavior if + // there is an endToEndReq. + // + // That said it's very convenient and provides way to control aspects + // of downstream compilation. + + for (const auto& define : + getTargetProgram()->getOptionSet().getArray(CompilerOptionName::MacroDefine)) + { + preprocessorDefinitions.addIfNotExists(define.stringValue, define.stringValue2); + } + } + + + // If we have an extension tracker, we may need to set options such as SPIR-V version + // and CUDA Shader Model. + if (extensionTracker) + { + // Look for the version + if (auto cudaTracker = as<CUDAExtensionTracker>(extensionTracker)) + { + cudaTracker->finalize(); + + if (cudaTracker->m_smVersion.isSet()) + { + DownstreamCompileOptions::CapabilityVersion version; + version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::CUDASM; + version.version = cudaTracker->m_smVersion; + + requiredCapabilityVersions.add(version); + } + + if (cudaTracker->isBaseTypeRequired(BaseType::Half)) + { + options.flags |= CompileOptions::Flag::EnableFloat16; + } + } + else if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) + { + DownstreamCompileOptions::CapabilityVersion version; + version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; + version.version = glslTracker->getSPIRVVersion(); + + requiredCapabilityVersions.add(version); + } + } + + CapabilitySet targetCaps = getTargetCaps(); + for (auto atomSets : targetCaps.getAtomSets()) + { + for (auto atomVal : atomSets) + { + auto atom = CapabilityAtom(atomVal); + switch (atom) + { + default: + break; + +#define CASE(KIND, NAME, VERSION) \ + case CapabilityAtom::NAME: \ + requiredCapabilityVersions.add(DownstreamCompileOptions::CapabilityVersion{ \ + DownstreamCompileOptions::CapabilityVersion::Kind::KIND, \ + VERSION}); \ + break + + CASE(CUDASM, _cuda_sm_1_0, SemanticVersion(1, 0)); + CASE(CUDASM, _cuda_sm_2_0, SemanticVersion(2, 0)); + CASE(CUDASM, _cuda_sm_3_0, SemanticVersion(3, 0)); + CASE(CUDASM, _cuda_sm_4_0, SemanticVersion(4, 0)); + CASE(CUDASM, _cuda_sm_5_0, SemanticVersion(5, 0)); + CASE(CUDASM, _cuda_sm_6_0, SemanticVersion(6, 0)); + CASE(CUDASM, _cuda_sm_7_0, SemanticVersion(7, 0)); + CASE(CUDASM, _cuda_sm_8_0, SemanticVersion(8, 0)); + CASE(CUDASM, _cuda_sm_9_0, SemanticVersion(9, 0)); + +#undef CASE + } + } + } + + // Set the file sytem and source manager, as *may* be used by downstream compiler + options.fileSystemExt = getFileSystemExt(); + options.sourceManager = getSourceManager(); + + // Set the source type + options.sourceLanguage = SlangSourceLanguage(sourceLanguage); + + switch (target) + { + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + // Disable exceptions and security checks + options.flags &= + ~(CompileOptions::Flag::EnableExceptionHandling | + CompileOptions::Flag::EnableSecurityChecks); + break; + } + + Profile profile; + + if (compilerType == PassThroughMode::Fxc || compilerType == PassThroughMode::Dxc || + compilerType == PassThroughMode::Glslang) + { + const auto entryPointIndices = getEntryPointIndices(); + auto targetReq = getTargetReq(); + + const auto entryPointIndicesCount = entryPointIndices.getCount(); + + // Whole program means + // * can have 0-N entry points + // * 'doesn't build into an executable/kernel' + // + // So in some sense it is a library + if (getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::GenerateWholeProgram)) + { + if (compilerType == PassThroughMode::Dxc) + { + // Can support no entry points on DXC because we can build libraries + profile = + Profile(getTargetProgram()->getOptionSet().getEnumOption<Profile::RawEnum>( + CompilerOptionName::Profile)); + } + else + { + auto downstreamCompilerName = + TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); + + sink->diagnose( + SourceLoc(), + Diagnostics::downstreamCompilerDoesntSupportWholeProgramCompilation, + downstreamCompilerName); + return SLANG_FAIL; + } + } + else if (entryPointIndicesCount == 1) + { + // All support a single entry point + const Index entryPointIndex = entryPointIndices[0]; + + auto entryPoint = getEntryPoint(entryPointIndex); + profile = getEffectiveProfile(entryPoint, targetReq); + + if (_shouldSetEntryPointName(getTargetProgram())) + { + options.entryPointName = allocator.allocate(getText(entryPoint->getName())); + auto entryPointNameOverride = + getProgram()->getEntryPointNameOverride(entryPointIndex); + if (entryPointNameOverride.getLength() != 0) + { + options.entryPointName = allocator.allocate(entryPointNameOverride); + } + } + } + else + { + // We only support a single entry point on this target + SLANG_ASSERT(!"Can only compile with a single entry point on this target"); + return SLANG_FAIL; + } + + options.stage = SlangStage(profile.getStage()); + + if (compilerType == PassThroughMode::Dxc) + { + // We will enable the flag to generate proper code for 16 - bit types + // by default, as long as the user is requesting a sufficiently + // high shader model. + // + // TODO: Need to check that this is safe to enable in all cases, + // or if it will make a shader demand hardware features that + // aren't always present. + // + // TODO: Ideally the dxc back-end should be passed some information + // on the "capabilities" that were used and/or requested in the code. + // + if (profile.getVersion() >= ProfileVersion::DX_6_2) + { + options.flags |= CompileOptions::Flag::EnableFloat16; + } + + // Set the matrix layout + options.matrixLayout = + (SlangMatrixLayoutMode)getTargetProgram()->getOptionSet().getMatrixLayoutMode(); + } + + // Set the profile + options.profileName = allocator.allocate(getHLSLProfileName(profile)); + } + + // If we aren't using LLVM 'host callable', we want downstream compile to produce a shared + // library + if (compilerType != PassThroughMode::LLVM && + ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).kind == + ArtifactKind::HostCallable) + { + target = CodeGenTarget::ShaderSharedLibrary; + } + + if (!isPassThroughEnabled()) + { + if (_isCPUHostTarget(target)) + { + libraryPaths.add(Path::getParentDirectory(Path::getExecutablePath())); + libraryPaths.add( + Path::combine(Path::getParentDirectory(Path::getExecutablePath()), "../lib")); + + // Set up the library artifact + auto artifact = Artifact::create( + ArtifactDesc::make(ArtifactKind::Library, Artifact::Payload::HostCPU), + toSlice("slang-rt")); + + ComPtr<IOSFileArtifactRepresentation> fileRep(new OSFileArtifactRepresentation( + IOSFileArtifactRepresentation::Kind::NameOnly, + toSlice("slang-rt"), + nullptr)); + artifact->addRepresentation(fileRep); + + libraries.add(artifact); + } + } + + options.targetType = (SlangCompileTarget)target; + + // Need to configure for the compilation + + { + auto linkage = getLinkage(); + + switch (getTargetProgram()->getOptionSet().getEnumOption<OptimizationLevel>( + CompilerOptionName::Optimization)) + { + case OptimizationLevel::None: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::None; + break; + case OptimizationLevel::Default: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Default; + break; + case OptimizationLevel::High: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::High; + break; + case OptimizationLevel::Maximal: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Maximal; + break; + default: + SLANG_ASSERT(!"Unhandled optimization level"); + break; + } + + switch (getTargetProgram()->getOptionSet().getEnumOption<DebugInfoLevel>( + CompilerOptionName::DebugInformation)) + { + case DebugInfoLevel::None: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::None; + break; + case DebugInfoLevel::Minimal: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Minimal; + break; + + case DebugInfoLevel::Standard: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Standard; + break; + case DebugInfoLevel::Maximal: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Maximal; + break; + default: + SLANG_ASSERT(!"Unhandled debug level"); + break; + } + + switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointMode>( + CompilerOptionName::FloatingPointMode)) + { + case FloatingPointMode::Default: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Default; + break; + case FloatingPointMode::Precise: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Precise; + break; + case FloatingPointMode::Fast: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Fast; + break; + default: + SLANG_ASSERT(!"Unhandled floating point mode"); + } + + if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp16)) + { + switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( + CompilerOptionName::DenormalModeFp16)) + { + case FloatingPointDenormalMode::Any: + options.denormalModeFp16 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; + break; + case FloatingPointDenormalMode::Preserve: + options.denormalModeFp16 = + DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; + break; + case FloatingPointDenormalMode::FlushToZero: + options.denormalModeFp16 = + DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; + break; + default: + SLANG_ASSERT(!"Unhandled fp16 denormal handling mode"); + } + } + + if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp32)) + { + switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( + CompilerOptionName::DenormalModeFp32)) + { + case FloatingPointDenormalMode::Any: + options.denormalModeFp32 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; + break; + case FloatingPointDenormalMode::Preserve: + options.denormalModeFp32 = + DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; + break; + case FloatingPointDenormalMode::FlushToZero: + options.denormalModeFp32 = + DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; + break; + default: + SLANG_ASSERT(!"Unhandled fp32 denormal handling mode"); + } + } + + if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp64)) + { + switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( + CompilerOptionName::DenormalModeFp64)) + { + case FloatingPointDenormalMode::Any: + options.denormalModeFp64 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; + break; + case FloatingPointDenormalMode::Preserve: + options.denormalModeFp64 = + DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; + break; + case FloatingPointDenormalMode::FlushToZero: + options.denormalModeFp64 = + DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; + break; + default: + SLANG_ASSERT(!"Unhandled fp64 denormal handling mode"); + } + } + + { + // We need to look at the stage of the entry point(s) we are + // being asked to compile, since this will determine the + // "pipeline" that the result should be compiled for (e.g., + // compute vs. ray tracing). + // + // TODO: This logic is kind of messy in that it assumes + // a program to be compiled will only contain kernels for + // a single pipeline type, but that invariant isn't expressed + // at all in the front-end today. It also has no error + // checking for the case where there are conflicts. + // + // HACK: Right now none of the above concerns matter + // because we always perform code generation on a single + // entry point at a time. + // + Index entryPointCount = getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) + { + auto stage = getEntryPoint(ee)->getStage(); + switch (stage) + { + default: + break; + + case Stage::Compute: + options.pipelineType = DownstreamCompileOptions::PipelineType::Compute; + break; + + case Stage::Vertex: + case Stage::Hull: + case Stage::Domain: + case Stage::Geometry: + case Stage::Fragment: + options.pipelineType = DownstreamCompileOptions::PipelineType::Rasterization; + break; + + case Stage::RayGeneration: + case Stage::Intersection: + case Stage::AnyHit: + case Stage::ClosestHit: + case Stage::Miss: + case Stage::Callable: + options.pipelineType = DownstreamCompileOptions::PipelineType::RayTracing; + break; + } + } + } + + // Add all the search paths (as calculated earlier - they will only be set if this is a pass + // through else will be empty) + options.includePaths = allocator.allocate(includePaths); + + // Add the specified defines (as calculated earlier - they will only be set if this is a + // pass through else will be empty) + { + const auto count = preprocessorDefinitions.getCount(); + auto dst = allocator.getArena().allocateArray<DownstreamCompileOptions::Define>(count); + + Index i = 0; + + for (const auto& [defKey, defValue] : preprocessorDefinitions) + { + auto& define = dst[i]; + + define.nameWithSig = allocator.allocate(defKey); + define.value = allocator.allocate(defValue); + + ++i; + } + options.defines = makeSlice(dst, count); + } + + // Add all of the module libraries + libraries.addRange(linkage->m_libModules.getBuffer(), linkage->m_libModules.getCount()); + } + + auto program = getProgram(); + + // Load embedded precompiled libraries from IR into library artifacts + program->enumerateIRModules( + [&](IRModule* irModule) + { + for (auto globalInst : irModule->getModuleInst()->getChildren()) + { + if (target == CodeGenTarget::DXILAssembly || target == CodeGenTarget::DXIL) + { + if (auto inst = as<IREmbeddedDownstreamIR>(globalInst)) + { + if (inst->getTarget() == CodeGenTarget::DXIL) + { + auto slice = inst->getBlob()->getStringSlice(); + ArtifactDesc desc = + ArtifactDescUtil::makeDescForCompileTarget(SLANG_DXIL); + desc.kind = ArtifactKind::Library; + + auto library = ArtifactUtil::createArtifact(desc); + + library->addRepresentationUnknown(StringBlob::create(slice)); + libraries.add(library); + } + } + } + } + }); + + options.compilerSpecificArguments = allocator.allocate(compilerSpecificArguments); + options.requiredCapabilityVersions = SliceUtil::asSlice(requiredCapabilityVersions); + options.libraries = SliceUtil::asSlice(libraries); + options.libraryPaths = allocator.allocate(libraryPaths); + + if (m_targetProfile.getFamily() == ProfileFamily::DX) + { + options.enablePAQ = m_targetProfile.getVersion() >= ProfileVersion::DX_6_7; + } + + // Compile + ComPtr<IArtifact> artifact; + auto downstreamStartTime = std::chrono::high_resolution_clock::now(); + SLANG_RETURN_ON_FAIL(compiler->compile(options, artifact.writeRef())); + auto downstreamElapsedTime = + (std::chrono::high_resolution_clock::now() - downstreamStartTime).count() * 0.000000001; + getSession()->addDownstreamCompileTime(downstreamElapsedTime); + + SLANG_RETURN_ON_FAIL(passthroughDownstreamDiagnostics(getSink(), compiler, artifact)); + + // Copy over all of the information associated with the source into the output + if (sourceArtifact) + { + for (auto associatedArtifact : sourceArtifact->getAssociated()) + { + artifact->addAssociated(associatedArtifact); + } + } + + // Set the artifact + outArtifact.swap(artifact); + return SLANG_OK; +} + +SlangResult emitSPIRVForEntryPointsDirectly( + CodeGenContext* codeGenContext, + ComPtr<IArtifact>& outArtifact); + +SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); + +static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) +{ + switch (target) + { + case CodeGenTarget::DXBytecodeAssembly: + return CodeGenTarget::DXBytecode; + case CodeGenTarget::DXILAssembly: + return CodeGenTarget::DXIL; + case CodeGenTarget::SPIRVAssembly: + return CodeGenTarget::SPIRV; + case CodeGenTarget::WGSLSPIRVAssembly: + return CodeGenTarget::WGSLSPIRV; + default: + return CodeGenTarget::None; + } +} + +IArtifact* getSeparateDbgArtifact(IArtifact* artifact) +{ + if (!artifact) + return nullptr; + + // The first associated artifact of kind ObjectCode and SPIRV payload should be the debug + // artifact. + for (auto* associated : artifact->getAssociated()) + { + auto desc = associated->getDesc(); + if (desc.kind == ArtifactKind::ObjectCode && desc.payload == ArtifactPayload::SPIRV) + return associated; + } + + return nullptr; +} + +/// Function to simplify the logic around emitting, and dissassembling +SlangResult CodeGenContext::_emitEntryPoints(ComPtr<IArtifact>& outArtifact) +{ + auto target = getTargetFormat(); + switch (target) + { + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXILAssembly: + case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::WGSLSPIRVAssembly: + { + // First compile to an intermediate target for the corresponding binary format. + const CodeGenTarget intermediateTarget = _getIntermediateTarget(target); + CodeGenContext intermediateContext(this, intermediateTarget); + + ComPtr<IArtifact> intermediateArtifact; + + SLANG_RETURN_ON_FAIL(intermediateContext._emitEntryPoints(intermediateArtifact)); + intermediateContext.maybeDumpIntermediate(intermediateArtifact); + + // Then disassemble the intermediate binary result to get the desired output + // Output the disassemble + ComPtr<IArtifact> disassemblyArtifact; + SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream( + getSession(), + intermediateArtifact, + getSink(), + disassemblyArtifact.writeRef())); + + // Also disassemble the debug artifact if one exists. + auto debugArtifact = getSeparateDbgArtifact(intermediateArtifact); + ComPtr<IArtifact> disassemblyDebugArtifact; + if (debugArtifact) + { + SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream( + getSession(), + debugArtifact, + getSink(), + disassemblyDebugArtifact.writeRef())); + disassemblyDebugArtifact->setName(debugArtifact->getName()); + + // The disassembly needs both the metadata for the debug build identifier + // and the debug spirv to be associated with is. + for (auto associated : intermediateArtifact->getAssociated()) + { + if (associated->getDesc().payload == ArtifactPayload::Metadata || + associated->getDesc().payload == ArtifactPayload::PostEmitMetadata) + { + disassemblyArtifact->addAssociated(associated); + break; + } + } + disassemblyArtifact->addAssociated(disassemblyDebugArtifact); + } + + outArtifact.swap(disassemblyArtifact); + return SLANG_OK; + } + case CodeGenTarget::SPIRV: + if (getTargetProgram()->getOptionSet().shouldEmitSPIRVDirectly()) + { + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(this, outArtifact)); + return SLANG_OK; + } + [[fallthrough]]; + case CodeGenTarget::DXIL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::MetalLib: + case CodeGenTarget::PTX: + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostSharedLibrary: + case CodeGenTarget::WGSLSPIRV: + SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); + return SLANG_OK; + case CodeGenTarget::HostVM: + SLANG_RETURN_ON_FAIL(emitHostVMCode(this, outArtifact)); + return SLANG_OK; + default: + break; + } + + return SLANG_FAIL; +} + +// Helper class for recording compile time. +struct CompileTimerRAII +{ + std::chrono::high_resolution_clock::time_point startTime; + Session* session; + CompileTimerRAII(Session* inSession) + { + startTime = std::chrono::high_resolution_clock::now(); + session = inSession; + } + ~CompileTimerRAII() + { + double elapsedTime = std::chrono::duration_cast<std::chrono::microseconds>( + std::chrono::high_resolution_clock::now() - startTime) + .count() / + 1e6; + session->addTotalCompileTime(elapsedTime); + } +}; + +// Do emit logic for a zero or more entry points +SlangResult CodeGenContext::emitEntryPoints(ComPtr<IArtifact>& outArtifact) +{ + CompileTimerRAII recordCompileTime(getSession()); + + auto target = getTargetFormat(); + + switch (target) + { + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXILAssembly: + case CodeGenTarget::SPIRV: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::PTX: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostSharedLibrary: + case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::HostVM: + { + SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); + + maybeDumpIntermediate(outArtifact); + return SLANG_OK; + } + break; + case CodeGenTarget::GLSL: + case CodeGenTarget::HLSL: + case CodeGenTarget::CUDASource: + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: + case CodeGenTarget::CSource: + case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: + { + RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); + + CodeGenContext subContext(this, target, extensionTracker); + + ComPtr<IArtifact> sourceArtifact; + + SLANG_RETURN_ON_FAIL(subContext.emitEntryPointsSource(sourceArtifact)); + + subContext.maybeDumpIntermediate(sourceArtifact); + outArtifact = sourceArtifact; + return SLANG_OK; + } + break; + + case CodeGenTarget::None: + // The user requested no output + return SLANG_OK; + + // Note(tfoley): We currently hit this case when compiling the core module + case CodeGenTarget::Unknown: + return SLANG_OK; + + default: + SLANG_UNEXPECTED("unhandled code generation target"); + break; + } + return SLANG_FAIL; +} + +void CodeGenContext::_dumpIntermediateMaybeWithAssembly(IArtifact* artifact) +{ + _dumpIntermediate(artifact); + + ComPtr<IArtifact> assembly; + ArtifactOutputUtil::maybeDisassemble(getSession(), artifact, nullptr, assembly); + + if (assembly) + { + _dumpIntermediate(assembly); + } +} + +void CodeGenContext::_dumpIntermediate(IArtifact* artifact) +{ + ComPtr<ISlangBlob> blob; + if (SLANG_FAILED(artifact->loadBlob(ArtifactKeep::No, blob.writeRef()))) + { + return; + } + _dumpIntermediate(artifact->getDesc(), blob->getBufferPointer(), blob->getBufferSize()); +} + +void CodeGenContext::_dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size) +{ + // Try to generate a unique ID for the file to dump, + // even in cases where there might be multiple threads + // doing compilation. + // + // This is primarily a debugging aid, so we don't + // really need/want to do anything too elaborate + + static std::atomic<uint32_t> counter(0); + + const uint32_t id = ++counter; + + // Just use the counter for the 'base name' + StringBuilder basename; + + // Add the prefix + basename << getIntermediateDumpPrefix(); + + // Add the id + basename << int(id); + + // Work out the filename based on the desc and the basename + StringBuilder filename; + ArtifactDescUtil::calcNameForDesc(desc, basename.getUnownedSlice(), filename); + + // If didn't produce a filename, use basename with .unknown extension + if (filename.getLength() == 0) + { + filename = basename; + filename << ".unknown"; + } + + // Write to a file + ArtifactOutputUtil::writeToFile(desc, data, size, filename); +} + +void CodeGenContext::maybeDumpIntermediate(IArtifact* artifact) +{ + if (!shouldDumpIntermediates()) + return; + + + _dumpIntermediateMaybeWithAssembly(artifact); +} + +IRDumpOptions CodeGenContext::getIRDumpOptions() +{ + if (auto endToEndReq = isEndToEndCompile()) + { + return endToEndReq->getFrontEndReq()->m_irDumpOptions; + } + return IRDumpOptions(); +} + +bool CodeGenContext::shouldValidateIR() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ValidateIr); +} + +bool CodeGenContext::shouldSkipSPIRVValidation() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::SkipSPIRVValidation); +} + +bool CodeGenContext::shouldDumpIR() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); +} + +bool CodeGenContext::shouldSkipDownstreamLinking() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::SkipDownstreamLinking); +} + +bool CodeGenContext::shouldReportCheckpointIntermediates() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::ReportCheckpointIntermediates); +} + +bool CodeGenContext::shouldDumpIntermediates() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); +} + +bool CodeGenContext::shouldTrackLiveness() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); +} + +String CodeGenContext::getIntermediateDumpPrefix() +{ + return getTargetProgram()->getOptionSet().getStringOption( + CompilerOptionName::DumpIntermediatePrefix); +} + +bool CodeGenContext::getUseUnknownImageFormatAsDefault() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::DefaultImageFormatUnknown); +} + +bool CodeGenContext::isSpecializationDisabled() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::DisableSpecialization); +} + +SlangResult CodeGenContext::requireTranslationUnitSourceFiles() +{ + if (auto endToEndReq = isPassThroughEnabled()) + { + for (auto entryPointIndex : getEntryPointIndices()) + { + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + SLANG_ASSERT(translationUnit); + /// Make sure we have the source files + SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); + } + } + + return SLANG_OK; +} + +} // namespace Slang diff --git a/source/slang/slang-code-gen.h b/source/slang/slang-code-gen.h new file mode 100644 index 000000000..81271abad --- /dev/null +++ b/source/slang/slang-code-gen.h @@ -0,0 +1,263 @@ +// slang-code-gen.h +#pragma once + +// +// This file defines the `CodeGenContext` type and related +// utilities. The `CodeGenContext` is used to bundle together +// the information needed by the back-end of the Slang +// compiler, and to help ensure that the back-end is not able +// to access (and thus rely on) information that should only +// be available to the front-end. Maintaining that split +// ensures that results are consistent between end-to-end +// and seaprate-compilation scenarios. +// + +#include "slang-entry-point.h" +#include "slang-session.h" +#include "slang-target-program.h" + +namespace Slang +{ + +/// A back-end-specific object to track optional feaures/capabilities/extensions +/// that are discovered to be used by a program/kernel as part of code generation. +class ExtensionTracker : public RefObject +{ + // TODO: The existence of this type is evidence of a design/architecture problem. + // + // A better formulation of things requires a few key changes: + // + // 1. All optional capabilities need to be enumerated as part of the `CapabilitySet` + // system, so that they can be reasoned about uniformly across different targets + // and different layers of the compiler. + // + // 2. The front-end should be responsible for either or both of: + // + // * Checking that `public` or otherwise externally-visible items (declarations/definitions) + // explicitly declare the capabilities they require, and that they only ever + // make use of items that are comatible with those required capabilities. + // + // * Inferring the capabilities required by items that are not externally visible, + // and attaching those capabilities explicit as a modifier or other synthesized AST node. + // + // 3. The capabilities required by a given `ComponentType` and its entry points should be + // explicitly know-able, and they should be something we can compare to the capabilities + // of a code generation target *before* back-end code generation is started. We should be + // able to issue error messages around lacking capabilities in a way the user can understand, + // in terms of the high-level-language entities. + +public: +}; + +struct RequiredLoweringPassSet +{ + bool debugInfo; + bool resultType; + bool optionalType; + bool enumType; + bool combinedTextureSamplers; + bool reinterpret; + bool generics; + bool bindExistential; + bool autodiff; + bool derivativePyBindWrapper; + bool bitcast; + bool existentialTypeLayout; + bool bindingQuery; + bool meshOutput; + bool higherOrderFunc; + bool globalVaryingVar; + bool glslSSBO; + bool byteAddressBuffer; + bool dynamicResource; + bool dynamicResourceHeap; + bool resolveVaryingInputRef; + bool specializeStageSwitch; + bool missingReturn; + bool nonVectorCompositeSelect; +}; + +/// A context for code generation in the compiler back-end +struct CodeGenContext +{ +public: + typedef List<Index> EntryPointIndices; + + struct Shared + { + public: + Shared( + TargetProgram* targetProgram, + EntryPointIndices const& entryPointIndices, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) + : targetProgram(targetProgram) + , entryPointIndices(entryPointIndices) + , sink(sink) + , endToEndReq(endToEndReq) + { + } + + // Shared( + // TargetProgram* targetProgram, + // EndToEndCompileRequest* endToEndReq); + + TargetProgram* targetProgram = nullptr; + EntryPointIndices entryPointIndices; + DiagnosticSink* sink = nullptr; + EndToEndCompileRequest* endToEndReq = nullptr; + }; + + CodeGenContext(Shared* shared) + : m_shared(shared) + , m_targetFormat(shared->targetProgram->getTargetReq()->getTarget()) + , m_targetProfile(shared->targetProgram->getOptionSet().getProfile()) + { + } + + CodeGenContext( + CodeGenContext* base, + CodeGenTarget targetFormat, + ExtensionTracker* extensionTracker = nullptr) + : m_shared(base->m_shared) + , m_targetFormat(targetFormat) + , m_extensionTracker(extensionTracker) + { + } + + /// Get the diagnostic sink + DiagnosticSink* getSink() { return m_shared->sink; } + + TargetProgram* getTargetProgram() { return m_shared->targetProgram; } + + EntryPointIndices const& getEntryPointIndices() { return m_shared->entryPointIndices; } + + CodeGenTarget getTargetFormat() { return m_targetFormat; } + + ExtensionTracker* getExtensionTracker() { return m_extensionTracker; } + + TargetRequest* getTargetReq() { return getTargetProgram()->getTargetReq(); } + + CapabilitySet getTargetCaps() { return getTargetReq()->getTargetCaps(); } + + CodeGenTarget getFinalTargetFormat() { return getTargetReq()->getTarget(); } + + ComponentType* getProgram() { return getTargetProgram()->getProgram(); } + + Linkage* getLinkage() { return getProgram()->getLinkage(); } + + Session* getSession() { return getLinkage()->getSessionImpl(); } + + /// Get the source manager + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + + EndToEndCompileRequest* isEndToEndCompile() { return m_shared->endToEndReq; } + + EndToEndCompileRequest* isPassThroughEnabled(); + + Count getEntryPointCount() { return getEntryPointIndices().getCount(); } + + EntryPoint* getEntryPoint(Index index) { return getProgram()->getEntryPoint(index); } + + Index getSingleEntryPointIndex() + { + SLANG_ASSERT(getEntryPointCount() == 1); + return getEntryPointIndices()[0]; + } + + // + + IRDumpOptions getIRDumpOptions(); + + bool shouldValidateIR(); + bool shouldDumpIR(); + bool shouldReportCheckpointIntermediates(); + + bool shouldTrackLiveness(); + + bool shouldDumpIntermediates(); + String getIntermediateDumpPrefix(); + + bool getUseUnknownImageFormatAsDefault(); + + bool isSpecializationDisabled(); + + bool shouldSkipSPIRVValidation(); + + SlangResult requireTranslationUnitSourceFiles(); + + // + + SlangResult emitEntryPoints(ComPtr<IArtifact>& outArtifact); + + SlangResult emitPrecompiledDownstreamIR(ComPtr<IArtifact>& outArtifact); + + void maybeDumpIntermediate(IArtifact* artifact); + + // Used to cause instructions available in precompiled blobs to be + // removed between IR linking and target source generation. + bool removeAvailableInDownstreamIR = false; + + // Determines if program level compilation like getTargetCode() or getEntryPointCode() + // should return a fully linked downstream program or just the glue SPIR-V/DXIL that + // imports and uses the precompiled SPIR-V/DXIL from constituent modules. + // This is a no-op if modules are not precompiled. + bool shouldSkipDownstreamLinking(); + + RequiredLoweringPassSet& getRequiredLoweringPassSet() { return m_requiredLoweringPassSet; } + +protected: + CodeGenTarget m_targetFormat = CodeGenTarget::Unknown; + Profile m_targetProfile; + ExtensionTracker* m_extensionTracker = nullptr; + + // To improve the performance of our backend, we will try to avoid running + // passes related to features not used in the user code. + // To do so, we will scan the IR module once, and determine which passes are needed + // based on the instructions used in the IR module. + // This will allow us to skip running passes that are not needed, without having to + // run all the passes only to find out that no work is needed. + // This is especially important for the performance of the backend, as some passes + // have an initialization cost (such as building reference graphs or DOM trees) that + // can be expensive. + RequiredLoweringPassSet m_requiredLoweringPassSet; + + /// Will output assembly as well as the artifact if appropriate for the artifact type for + /// assembly output and conversion is possible + void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact); + + void _dumpIntermediate(IArtifact* artifact); + void _dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size); + + /* Emits entry point source taking into account if a pass-through or not. Uses 'targetFormat' to + determine the target (not targetReq) */ + SlangResult emitEntryPointsSource(ComPtr<IArtifact>& outArtifact); + + SlangResult emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outArtifact); + + SlangResult emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& outArtifact); + + /* Determines a suitable filename to identify the input for a given entry point being compiled. + If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file + pathname for the translation unit containing the entry point at `entryPointIndex. + If the compilation is not in a pass-through case, then always returns `"slang-generated"`. + @param endToEndReq The end-to-end compile request which might be using pass-through compilation + @param entryPointIndex The index of the entry point to compute a filename for. + @return the appropriate source filename */ + String calcSourcePathForEntryPoints(); + + TranslationUnitRequest* findPassThroughTranslationUnit(Int entryPointIndex); + + + SlangResult _emitEntryPoints(ComPtr<IArtifact>& outArtifact); + +private: + Shared* m_shared = nullptr; +}; + +// TODO: The "artifact" system is a scourge. +IArtifact* getSeparateDbgArtifact(IArtifact* artifact); + +} // namespace Slang diff --git a/source/slang/slang-compile-request.cpp b/source/slang/slang-compile-request.cpp new file mode 100644 index 000000000..6cbf79f96 --- /dev/null +++ b/source/slang/slang-compile-request.cpp @@ -0,0 +1,703 @@ +// slang-compile-request.cpp +#include "slang-compile-request.h" + +#include "../core/slang-performance-profiler.h" +#include "compiler-core/slang-artifact-desc-util.h" +#include "compiler-core/slang-artifact-util.h" +#include "slang-ast-dump.h" +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-emit-source-writer.h" +#include "slang-lower-to-ir.h" +#include "slang-parser.h" +#include "slang-serialize-container.h" + +namespace Slang +{ +// +// FrontEndEntryPointRequest +// + +FrontEndEntryPointRequest::FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile) + : m_compileRequest(compileRequest) + , m_translationUnitIndex(translationUnitIndex) + , m_name(name) + , m_profile(profile) +{ +} + +TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() +{ + return getCompileRequest()->translationUnits[m_translationUnitIndex]; +} + +// +// CompileRequestBase +// + +CompileRequestBase::CompileRequestBase(Linkage* linkage, DiagnosticSink* sink) + : m_linkage(linkage), m_sink(sink) +{ +} + +Session* CompileRequestBase::getSession() +{ + return getLinkage()->getSessionImpl(); +} + +// +// FrontEndCompileRequest +// + +FrontEndCompileRequest::FrontEndCompileRequest( + Linkage* linkage, + StdWriters* writers, + DiagnosticSink* sink) + : CompileRequestBase(linkage, sink), m_writers(writers) +{ + optionSet.inheritFrom(linkage->m_optionSet); +} + +// Holds the hierarchy of views, the children being views that were 'initiated' (have an initiating +// SourceLoc) in the parent. +typedef Dictionary<SourceView*, List<SourceView*>> ViewInitiatingHierarchy; + +// Calculate the hierarchy from the sourceManager +static void _calcViewInitiatingHierarchy( + SourceManager* sourceManager, + ViewInitiatingHierarchy& outHierarchy) +{ + const List<SourceView*> emptyList; + outHierarchy.clear(); + + // Iterate over all managers + for (SourceManager* curManager = sourceManager; curManager; + curManager = curManager->getParent()) + { + // Iterate over all views + for (SourceView* view : curManager->getSourceViews()) + { + if (view->getInitiatingSourceLoc().isValid()) + { + // Look up the view it came from + SourceView* parentView = + sourceManager->findSourceViewRecursively(view->getInitiatingSourceLoc()); + if (parentView) + { + List<SourceView*>& children = outHierarchy.getOrAddValue(parentView, emptyList); + // It shouldn't have already been added + SLANG_ASSERT(children.indexOf(view) < 0); + children.add(view); + } + } + } + } + + // Order all the children, by their raw SourceLocs. This is desirable, so that a trivial + // traversal will traverse children in the order they are initiated in the parent source. This + // assumes they increase in SourceLoc implies an later within a source file - this is true + // currently. + for (auto& [_, value] : outHierarchy) + { + value.sort( + [](SourceView* a, SourceView* b) -> bool { + return a->getInitiatingSourceLoc().getRaw() < b->getInitiatingSourceLoc().getRaw(); + }); + } +} + +// Given a source file, find the view that is the initial SourceView use of the source. It must have +// an initiating SourceLoc that is not valid. +static SourceView* _findInitialSourceView(SourceFile* sourceFile) +{ + // TODO(JS): + // This might be overkill - presumably the SourceView would belong to the same manager as it's + // SourceFile? That is not enforced by the SourceManager in any way though so we just search all + // managers, and all views. + for (SourceManager* sourceManager = sourceFile->getSourceManager(); sourceManager; + sourceManager = sourceManager->getParent()) + { + for (SourceView* view : sourceManager->getSourceViews()) + { + if (view->getSourceFile() == sourceFile && !view->getInitiatingSourceLoc().isValid()) + { + return view; + } + } + } + + return nullptr; +} + +static void _outputInclude(SourceFile* sourceFile, Index depth, DiagnosticSink* sink) +{ + StringBuilder buf; + + for (Index i = 0; i < depth; ++i) + { + buf << " "; + } + + // Output the found path for now + // TODO(JS). We could use the verbose paths flag to control what path is output -> as it may be + // useful to output the full path for example + + const PathInfo& pathInfo = sourceFile->getPathInfo(); + buf << "'" << pathInfo.foundPath << "'"; + + // TODO(JS)? + // You might want to know where this include was from. + // If I output this though there will be a problem... as the indenting won't be clearly shown. + // Perhaps I output in two sections, one the hierarchy and the other the locations of the + // includes? + + sink->diagnose(SourceLoc(), Diagnostics::includeOutput, buf); +} + +static void _outputIncludesRec( + SourceView* sourceView, + Index depth, + ViewInitiatingHierarchy& hierarchy, + DiagnosticSink* sink) +{ + SourceFile* sourceFile = sourceView->getSourceFile(); + const PathInfo& pathInfo = sourceFile->getPathInfo(); + + switch (pathInfo.type) + { + case PathInfo::Type::TokenPaste: + case PathInfo::Type::CommandLine: + case PathInfo::Type::TypeParse: + { + // If any of these types we don't output + return; + } + default: + break; + } + + // Okay output this file at the current depth + _outputInclude(sourceFile, depth, sink); + + // Now recurse to all of the children at the next depth + List<SourceView*>* children = hierarchy.tryGetValue(sourceView); + if (children) + { + for (SourceView* child : *children) + { + _outputIncludesRec(child, depth + 1, hierarchy, sink); + } + } +} + +static void _outputPreprocessorTokens(const TokenList& toks, ISlangWriter* writer) +{ + if (writer == nullptr) + { + return; + } + + StringBuilder buf; + for (const auto& tok : toks) + { + buf << tok.getContent(); + // We'll separate tokens with space for now + buf.appendChar(' '); + } + + buf.appendChar('\n'); + + writer->write(buf.getBuffer(), buf.getLength()); +} + +static void _outputIncludes( + const List<SourceFile*>& sourceFiles, + SourceManager* sourceManager, + DiagnosticSink* sink) +{ + // Set up the hierarchy to know how all the source views relate. This could be argued as + // overkill, but makes recursive output pretty simple + ViewInitiatingHierarchy hierarchy; + _calcViewInitiatingHierarchy(sourceManager, hierarchy); + + // For all the source files + for (SourceFile* sourceFile : sourceFiles) + { + if (sourceFile->isIncludedFile()) + continue; + + // Find an initial view (this is the view of this file, that doesn't have an initiating loc) + SourceView* sourceView = _findInitialSourceView(sourceFile); + if (!sourceView) + { + // Okay, didn't find one, so just output the file + _outputInclude(sourceFile, 0, sink); + } + else + { + // Output from this view recursively + _outputIncludesRec(sourceView, 0, hierarchy, sink); + } + } +} + +void FrontEndCompileRequest::parseTranslationUnit(TranslationUnitRequest* translationUnit) +{ + SLANG_PROFILE; + if (translationUnit->isChecked) + return; + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + // TODO(JS): NOTE! Here we are using the searchDirectories on the linkage. This is because + // currently the API only allows the setting search paths on linkage. + // + // Here we should probably be using the searchDirectories on the FrontEndCompileRequest. + // If searchDirectories.parent pointed to the one in the Linkage would mean linkage paths + // would be checked too (after those on the FrontEndCompileRequest). + IncludeSystem includeSystem( + &linkage->getSearchDirectories(), + linkage->getFileSystemExt(), + linkage->getSourceManager()); + + auto combinedPreprocessorDefinitions = translationUnit->getCombinedPreprocessorDefinitions(); + + auto module = translationUnit->getModule(); + + ASTBuilder* astBuilder = module->getASTBuilder(); + + ModuleDecl* translationUnitSyntax = astBuilder->create<ModuleDecl>(); + + translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; + translationUnitSyntax->module = module; + module->setModuleDecl(translationUnitSyntax); + + // When compiling a module of code that belongs to the Slang + // core module, we add a modifier to the module to act + // as a marker, so that downstream code can detect declarations + // that came from the core module (by walking up their + // chain of ancestors and looking for the marker), and treat + // them differently from user declarations. + // + // We are adding the marker here, before we even parse the + // code in the module, in case the subsequent steps would + // like to treat the core module differently. Alternatively + // we could pass down the `m_isStandardLibraryCode` flag to + // these passes. + // + if (m_isCoreModuleCode) + { + translationUnitSyntax->modifiers.first = astBuilder->create<FromCoreModuleModifier>(); + } + + // We use a custom handler for preprocessor callbacks, to + // ensure that relevant state that is only visible during + // preprocessoing can be communicated to later phases of + // compilation. + // + FrontEndPreprocessorHandler preprocessorHandler(module, astBuilder, getSink(), translationUnit); + + for (auto sourceFile : translationUnit->getSourceFiles()) + { + module->getIncludedSourceFileMap().addIfNotExists(sourceFile, nullptr); + } + + for (auto sourceFile : translationUnit->getSourceFiles()) + { + SourceLanguage sourceLanguage = translationUnit->sourceLanguage; + SlangLanguageVersion languageVersion = + translationUnit->compileRequest->optionSet.getLanguageVersion(); + auto tokens = preprocessSource( + sourceFile, + getSink(), + &includeSystem, + combinedPreprocessorDefinitions, + getLinkage(), + sourceLanguage, + languageVersion, + &preprocessorHandler); + + translationUnitSyntax->languageVersion = languageVersion; + + if (sourceLanguage == SourceLanguage::Unknown) + sourceLanguage = translationUnit->sourceLanguage; + + Scope* languageScope = nullptr; + switch (sourceLanguage) + { + case SourceLanguage::HLSL: + languageScope = getSession()->hlslLanguageScope; + break; + case SourceLanguage::GLSL: + languageScope = getSession()->glslLanguageScope; + break; + case SourceLanguage::Slang: + default: + languageScope = getSession()->slangLanguageScope; + break; + } + + if (optionSet.getBoolOption(CompilerOptionName::OutputIncludes)) + { + _outputIncludes( + translationUnit->getSourceFiles(), + getSink()->getSourceManager(), + getSink()); + } + + if (optionSet.getBoolOption(CompilerOptionName::PreprocessorOutput)) + { + if (m_writers) + { + _outputPreprocessorTokens( + tokens, + m_writers->getWriter(SLANG_WRITER_CHANNEL_STD_OUTPUT)); + } + // If we output the preprocessor output then we are done doing anything else + return; + } + + parseSourceFile( + astBuilder, + translationUnit, + sourceLanguage, + tokens, + getSink(), + languageScope, + translationUnitSyntax); + + // Let's try dumping + + if (optionSet.getBoolOption(CompilerOptionName::DumpAst)) + { + StringBuilder buf; + SourceWriter writer(linkage->getSourceManager(), LineDirectiveMode::None, nullptr); + + ASTDumpUtil::dump( + translationUnit->getModuleDecl(), + ASTDumpUtil::Style::Flat, + 0, + &writer); + + const String& path = sourceFile->getPathInfo().foundPath; + if (path.getLength()) + { + String fileName = Path::getFileNameWithoutExt(path); + fileName.append(".slang-ast"); + + File::writeAllText(fileName, writer.getContent()); + } + } + +#if 0 + // Test serialization + { + ASTSerialTestUtil::testSerialize(translationUnit->getModuleDecl(), getSession()->getNamePool(), getLinkage()->getASTBuilder()->getSharedASTBuilder(), getSourceManager()); + } +#endif + } +} + +void FrontEndCompileRequest::checkAllTranslationUnits() +{ + SLANG_PROFILE; + + LoadedModuleDictionary loadedModules; + if (additionalLoadedModules) + loadedModules = *additionalLoadedModules; + + // Iterate over all translation units and + // apply the semantic checking logic. + for (auto& translationUnit : translationUnits) + { + if (translationUnit->isChecked) + continue; + + checkTranslationUnit(translationUnit.Ptr(), loadedModules); + + // Add the checked module to list of loadedModules so that they can be + // discovered by `findOrImportModule` when processing future `import` decls. + // TODO: this does not handle the case where a translation unit to discover + // another translation unit added later to the compilation request. + // We should output an error message when we detect such a case, or support + // this scenario with a recursive style checking. + loadedModules.add(translationUnit->moduleName, translationUnit->getModule()); + } + checkEntryPoints(); +} + +void FrontEndCompileRequest::generateIR() +{ + SLANG_PROFILE; + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + + // Our task in this function is to generate IR code + // for all of the declarations in the translation + // units that were loaded. + + // Each translation unit is its own little world + // for code generation (we are not trying to + // replicate the GLSL linkage model), and so + // we will generate IR for each (if needed) + // in isolation. + for (auto& translationUnit : translationUnits) + { + // Skip if the module is precompiled. + if (translationUnit->getModule()->getIRModule()) + continue; + + // We want to only run generateIRForTranslationUnit once here. This is for two side effects: + // * it can dump ir + // * it can generate diagnostics + + /// Generate IR for translation unit. + RefPtr<IRModule> irModule( + generateIRForTranslationUnit(getLinkage()->getASTBuilder(), translationUnit)); + + if (verifyDebugSerialization) + { + SerialContainerUtil::WriteOptions options; + + options.sourceManagerToUseWhenSerializingSourceLocs = getSourceManager(); + + // Verify debug information + if (SLANG_FAILED( + SerialContainerUtil::verifyIRSerialize(irModule, getSession(), options))) + { + getSink()->diagnose( + irModule->getModuleInst()->sourceLoc, + Diagnostics::serialDebugVerificationFailed); + } + } + + // Set the module on the translation unit + translationUnit->getModule()->setIRModule(irModule); + } +} + +SlangResult FrontEndCompileRequest::executeActionsInner() +{ + SLANG_PROFILE_SECTION(frontEndExecute); + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + + for (TranslationUnitRequest* translationUnit : translationUnits) + { + // Make sure SourceFile representation is available for all translationUnits + SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); + } + + + // Parse everything from the input files requested + for (TranslationUnitRequest* translationUnit : translationUnits) + { + parseTranslationUnit(translationUnit); + } + + if (optionSet.getBoolOption(CompilerOptionName::PreprocessorOutput)) + { + // If doing pre-processor output, then we are done + return SLANG_OK; + } + + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + // Perform semantic checking on the whole collection + { + SLANG_PROFILE_SECTION(SemanticChecking); + checkAllTranslationUnits(); + } + + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + // After semantic checking is performed we can try and output doc information for this + if (optionSet.getBoolOption(CompilerOptionName::Doc)) + { + // TODO: implement the logic to output generated documents to target directory/zip file. + } + + // Look up all the entry points that are expected, + // and use them to populate the `program` member. + // + m_globalComponentType = createUnspecializedGlobalComponentType(this); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + m_globalAndEntryPointsComponentType = + createUnspecializedGlobalAndEntryPointsComponentType(this, m_unspecializedEntryPoints); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + // We always generate IR for all the translation units. + // + // TODO: We may eventually have a mode where we skip + // IR codegen and only produce an AST (e.g., for use when + // debugging problems in the parser or semantic checking), + // but for now there are no cases where not having IR + // makes sense. + // + generateIR(); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + // Do parameter binding generation, for each compilation target. + // + for (auto targetReq : getLinkage()->targets) + { + auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); + targetProgram->getOrCreateIRModuleForLayout(getSink()); + } + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + return SLANG_OK; +} + +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) +{ + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this); + translationUnit->compileRequest = this; + translationUnit->sourceLanguage = SourceLanguage(language); + + translationUnit->setModuleName(moduleName); + return addTranslationUnit(translationUnit); +} + +int FrontEndCompileRequest::addTranslationUnit(TranslationUnitRequest* translationUnit) +{ + Index result = translationUnits.getCount(); + translationUnits.add(translationUnit); + return (int)result; +} + +void FrontEndCompileRequest::addTranslationUnitSourceArtifact( + int translationUnitIndex, + IArtifact* sourceArtifact) +{ + auto translationUnit = translationUnits[translationUnitIndex]; + + // Add the source file + translationUnit->addSourceArtifact(sourceArtifact); + + if (!translationUnit->moduleName) + { + translationUnit->setModuleName( + getNamePool()->getName(Path::getFileNameWithoutExt(sourceArtifact->getName()))); + } + if (translationUnit->module->getFilePath() == nullptr) + translationUnit->module->setPathInfo(PathInfo::makePath(sourceArtifact->getName())); +} + +void FrontEndCompileRequest::addTranslationUnitSourceBlob( + int translationUnitIndex, + String const& path, + ISlangBlob* sourceBlob) +{ + auto translationUnit = translationUnits[translationUnitIndex]; + auto sourceDesc = + ArtifactDescUtil::makeDescForSourceLanguage(asExternal(translationUnit->sourceLanguage)); + + auto artifact = ArtifactUtil::createArtifact(sourceDesc, path.getBuffer()); + artifact->addRepresentationUnknown(sourceBlob); + + addTranslationUnitSourceArtifact(translationUnitIndex, artifact); +} + +void FrontEndCompileRequest::addTranslationUnitSourceFile( + int translationUnitIndex, + String const& path) +{ + // TODO: We need to consider whether a relative `path` should cause + // us to look things up using the registered search paths. + // + // This behavior wouldn't make sense for command-line invocations + // of `slangc`, but at least one API user wondered by the search + // paths were not taken into account by this function. + // + + auto fileSystemExt = getLinkage()->getFileSystemExt(); + auto translationUnit = getTranslationUnit(translationUnitIndex); + + auto sourceDesc = + ArtifactDescUtil::makeDescForSourceLanguage(asExternal(translationUnit->sourceLanguage)); + + auto sourceArtifact = ArtifactUtil::createArtifact(sourceDesc, path.getBuffer()); + + auto extRep = new ExtFileArtifactRepresentation(path.getUnownedSlice(), fileSystemExt); + sourceArtifact->addRepresentation(extRep); + + SlangResult existsRes = SLANG_OK; + + // If we require caching, we demand it's loaded here. + // + // In practice this probably means repro capture is enabled. So we want to + // load the blob such that it's in the cache, even if it doesn't actually + // have to be loaded for the compilation. + if (getLinkage()->m_requireCacheFileSystem) + { + ComPtr<ISlangBlob> blob; + // If we can load the blob, then it exists + existsRes = sourceArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); + } + else + { + existsRes = sourceArtifact->exists() ? SLANG_OK : SLANG_E_NOT_FOUND; + } + + if (SLANG_FAILED(existsRes)) + { + // Emit a diagnostic! + getSink()->diagnose(SourceLoc(), Diagnostics::cannotOpenFile, path); + return; + } + + addTranslationUnitSourceArtifact(translationUnitIndex, sourceArtifact); +} + +int FrontEndCompileRequest::addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile) +{ + auto translationUnitReq = translationUnits[translationUnitIndex]; + + Index result = m_entryPointReqs.getCount(); + + RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest( + this, + translationUnitIndex, + getNamePool()->getName(name), + entryPointProfile); + + m_entryPointReqs.add(entryPointReq); + // translationUnitReq->entryPoints.add(entryPointReq); + + return int(result); +} + +int EndToEndCompileRequest::addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile, + List<String> const& genericTypeNames) +{ + getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile); + + EntryPointInfo entryPointInfo; + for (auto typeName : genericTypeNames) + entryPointInfo.specializationArgStrings.add(typeName); + + Index result = m_entryPoints.getCount(); + m_entryPoints.add(_Move(entryPointInfo)); + return (int)result; +} + +} // namespace Slang diff --git a/source/slang/slang-compile-request.h b/source/slang/slang-compile-request.h new file mode 100644 index 000000000..4dee026aa --- /dev/null +++ b/source/slang/slang-compile-request.h @@ -0,0 +1,362 @@ +// slang-compile-request.h +#pragma once + +// +// This file contains the `FrontEndCompileRequest` type +// and the types that it is built from (such as +// `TranslationUnitRequest`). These types are used +// whenever the Slang front-end is invoked to compile +// a module (or, in some cases, one or more modules) +// from source code to a checked AST and Slang IR. +// +// Note that the `EndToEndCompileRequest` type has its +// own header: `slang-end-to-end-request.h`. +// + +#include "../compiler-core/slang-artifact.h" +#include "../compiler-core/slang-source-loc.h" +#include "../core/slang-smart-pointer.h" +#include "../core/slang-std-writers.h" +#include "slang-compiler-fwd.h" +#include "slang-diagnostics.h" +#include "slang-module.h" +#include "slang-preprocessor.h" +#include "slang-profile.h" +#include "slang-session.h" +#include "slang-translation-unit.h" + +namespace Slang +{ + +/// Shared functionality between front- and back-end compile requests. +/// +/// This is the base class for both `FrontEndCompileRequest` and +/// `BackEndCompileRequest`, and allows a small number of parts of +/// the compiler to be easily invocable from either front-end or +/// back-end work. +/// +class CompileRequestBase : public RefObject +{ + // TODO: We really shouldn't need this type in the long run. + // The few places that rely on it should be refactored to just + // depend on the underlying information (a linkage and a diagnostic + // sink) directly. + // + // The flags to control dumping and validation of IR should be + // moved to some kind of shared settings/options `struct` that + // both front-end and back-end requests can store. + +public: + Session* getSession(); + Linkage* getLinkage() { return m_linkage; } + DiagnosticSink* getSink() { return m_sink; } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) + { + return getLinkage()->loadFile(path, outPathInfo, outBlob); + } + +protected: + CompileRequestBase(Linkage* linkage, DiagnosticSink* sink); + +private: + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; +}; + +/// A request for the front-end to find and validate an entry-point function +struct FrontEndEntryPointRequest : RefObject +{ +public: + /// Create a request for an entry point. + FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile); + + /// Get the parent front-end compile request. + FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + + /// Get the translation unit that contains the entry point. + TranslationUnitRequest* getTranslationUnit(); + + /// Get the name of the entry point to find. + Name* getName() { return m_name; } + + /// Get the stage that the entry point is to be compiled for + Stage getStage() { return m_profile.getStage(); } + + /// Get the profile that the entry point is to be compiled for + Profile getProfile() { return m_profile; } + + /// Get the index to the translation unit + int getTranslationUnitIndex() const { return m_translationUnitIndex; } + +private: + // The parent compile request + FrontEndCompileRequest* m_compileRequest; + + // The index of the translation unit that will hold the entry point + int m_translationUnitIndex; + + // The name of the entry point function to look for + Name* m_name; + + // The profile to compile for (including stage) + Profile m_profile; +}; + +/// A request to compile source code to an AST + IR. +class FrontEndCompileRequest : public CompileRequestBase +{ +public: + /// Note that writers can be parsed as nullptr to disable output, + /// and individual channels set to null to disable them + FrontEndCompileRequest(Linkage* linkage, StdWriters* writers, DiagnosticSink* sink); + + int addEntryPoint(int translationUnitIndex, String const& name, Profile entryPointProfile); + + // Translation units we are being asked to compile + List<RefPtr<TranslationUnitRequest>> translationUnits; + + // Additional modules that needs to be made visible to `import` while checking. + const LoadedModuleDictionary* additionalLoadedModules = nullptr; + + RefPtr<TranslationUnitRequest> getTranslationUnit(UInt index) + { + return translationUnits[index]; + } + + // If true will serialize and de-serialize with debug information + bool verifyDebugSerialization = false; + + CompilerOptionSet optionSet; + + List<RefPtr<FrontEndEntryPointRequest>> m_entryPointReqs; + + List<RefPtr<FrontEndEntryPointRequest>> const& getEntryPointReqs() { return m_entryPointReqs; } + UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } + FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } + + void parseTranslationUnit(TranslationUnitRequest* translationUnit); + + // Perform primary semantic checking on all + // of the translation units in the program + void checkAllTranslationUnits(); + + void checkEntryPoints(); + + void generateIR(); + + SlangResult executeActionsInner(); + + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., + /// `SourceLanguage::Slang` + /// @param moduleName The name that will be used for the module compile from the translation + /// unit. + /// + /// If moduleName is passed as nullptr a module name is generated. + /// If all translation units in a compile request use automatically generated + /// module names, then they are guaranteed not to conflict with one another. + /// + /// @return The zero-based index of the translation unit in this compile request. + int addTranslationUnit(SourceLanguage language, Name* moduleName); + + int addTranslationUnit(TranslationUnitRequest* translationUnit); + + void addTranslationUnitSourceArtifact(int translationUnitIndex, IArtifact* sourceArtifact); + + void addTranslationUnitSourceBlob( + int translationUnitIndex, + String const& path, + ISlangBlob* sourceBlob); + + void addTranslationUnitSourceFile(int translationUnitIndex, String const& path); + + /// Get a component type that represents the global scope of the compile request. + ComponentType* getGlobalComponentType() { return m_globalComponentType; } + + /// Get a component type that represents the global scope of the compile request, plus the + /// requested entry points. + ComponentType* getGlobalAndEntryPointsComponentType() + { + return m_globalAndEntryPointsComponentType; + } + + List<RefPtr<ComponentType>> const& getUnspecializedEntryPoints() + { + return m_unspecializedEntryPoints; + } + + /// Does the code we are compiling represent part of the Slang core module? + bool m_isCoreModuleCode = false; + + Name* m_defaultModuleName = nullptr; + + /// The irDumpOptions + IRDumpOptions m_irDumpOptions; + + /// An "extra" entry point that was added via a library reference + struct ExtraEntryPointInfo + { + Name* name; + Profile profile; + String mangledName; + }; + + /// A list of "extra" entry points added via a library reference + List<ExtraEntryPointInfo> m_extraEntryPoints; + +private: + /// A component type that includes only the global scopes of the translation unit(s) that were + /// compiled. + RefPtr<ComponentType> m_globalComponentType; + + /// A component type that extends the global scopes with all of the entry points that were + /// specified. + RefPtr<ComponentType> m_globalAndEntryPointsComponentType; + + List<RefPtr<ComponentType>> m_unspecializedEntryPoints; + + RefPtr<StdWriters> m_writers; +}; + +/// Handlers for preprocessor callbacks to use when doing ordinary front-end compilation +struct FrontEndPreprocessorHandler : PreprocessorHandler +{ +public: + FrontEndPreprocessorHandler( + Module* module, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + TranslationUnitRequest* translationUnit) + : m_module(module) + , m_astBuilder(astBuilder) + , m_sink(sink) + , m_translationUnit(translationUnit) + { + } + +protected: + Module* m_module; + ASTBuilder* m_astBuilder; + DiagnosticSink* m_sink; + TranslationUnitRequest* m_translationUnit = nullptr; + + // The first task that this handler tries to deal with is + // capturing all the files on which a module is dependent. + // + // That information is exposed through public APIs and used + // by applications to decide when they need to "hot reload" + // their shader code. + // + void handleFileDependency(SourceFile* sourceFile) SLANG_OVERRIDE + { + m_module->addFileDependency(sourceFile); + m_translationUnit->addIncludedSourceFileIfNotExist(sourceFile); + } + + // The second task that this handler deals with is detecting + // whether any macro values were set in a given source file + // that are semantically relevant to other stages of compilation. + // + void handleEndOfTranslationUnit(Preprocessor* preprocessor) SLANG_OVERRIDE + { + // We look at the preprocessor state after reading the entire + // source file/string, in order to see if any macros have been + // set that should be considered semantically relevant for + // later stages of compilation. + // + // Note: Checking the macro environment *after* preprocessing is complete + // means that we can treat macros introduced via `-D` options or the API + // equivalently to macros introduced via `#define`s in user code. + // + // For now, the only case of semantically-relevant macros we need to worrry + // about are the NVAPI macros used to establish the register/space to use. + // + static const char* kNVAPIRegisterMacroName = "NV_SHADER_EXTN_SLOT"; + static const char* kNVAPISpaceMacroName = "NV_SHADER_EXTN_REGISTER_SPACE"; + + // For NVAPI use, the `NV_SHADER_EXTN_SLOT` macro is required to be defined. + // + String nvapiRegister; + SourceLoc nvapiRegisterLoc; + if (!SLANG_FAILED(findMacroValue( + preprocessor, + kNVAPIRegisterMacroName, + nvapiRegister, + nvapiRegisterLoc))) + { + // In contrast, NVAPI can be used without defining `NV_SHADER_EXTN_REGISTER_SPACE`, + // which effectively defaults to `space0`. + // + String nvapiSpace = "space0"; + SourceLoc nvapiSpaceLoc; + findMacroValue(preprocessor, kNVAPISpaceMacroName, nvapiSpace, nvapiSpaceLoc); + + // We are going to store the values of these macros on the AST-level `ModuleDecl` + // so that they will be available to later processing stages. + // + auto moduleDecl = m_module->getModuleDecl(); + + if (auto existingModifier = moduleDecl->findModifier<NVAPISlotModifier>()) + { + // If there is already a modifier attached to the module (perhaps + // because of preprocessing a different source file, or because + // of settings established via command-line options), then we + // need to validate that the values being set in this file + // match those already set (or else there is likely to be + // some kind of error in the user's code). + // + _validateNVAPIMacroMatch( + kNVAPIRegisterMacroName, + existingModifier->registerName, + nvapiRegister, + nvapiRegisterLoc); + _validateNVAPIMacroMatch( + kNVAPISpaceMacroName, + existingModifier->spaceName, + nvapiSpace, + nvapiSpaceLoc); + } + else + { + // If there is no existing modifier on the module, then we + // take responsibility for adding one, based on the macro + // values we saw. + // + auto modifier = m_astBuilder->create<NVAPISlotModifier>(); + modifier->loc = nvapiRegisterLoc; + modifier->registerName = nvapiRegister; + modifier->spaceName = nvapiSpace; + + addModifier(moduleDecl, modifier); + } + } + } + + /// Validate that a re-defintion of an NVAPI-related macro matches any previous definition + void _validateNVAPIMacroMatch( + char const* macroName, + String const& existingValue, + String const& newValue, + SourceLoc loc) + { + if (existingValue != newValue) + { + m_sink->diagnose( + loc, + Diagnostics::nvapiMacroMismatch, + macroName, + existingValue, + newValue); + } + } +}; + +} // namespace Slang diff --git a/source/slang/slang-compiler-api.cpp b/source/slang/slang-compiler-api.cpp new file mode 100644 index 000000000..6e3e8615f --- /dev/null +++ b/source/slang/slang-compiler-api.cpp @@ -0,0 +1,2 @@ +// slang-compiler-api.cpp +#include "slang-compiler-api.h" diff --git a/source/slang/slang-compiler-api.h b/source/slang/slang-compiler-api.h new file mode 100644 index 000000000..cfcf65a7f --- /dev/null +++ b/source/slang/slang-compiler-api.h @@ -0,0 +1,124 @@ +// slang-compiler-api.h +#pragma once + +// +// This file provides utilities that are needed at the boundary +// between the public Slang API (the interfaces declared in +// `slang.h`) and the code that implements that API. +// + +#include "slang-end-to-end-request.h" +#include "slang-global-session.h" +#include "slang-module.h" +#include "slang-session.h" + +#include <slang.h> + +namespace Slang +{ + +// +// The following functions are utilties to convert between +// matching "external" (public API) and "internal" (implementation) +// types. They are favored over explicit casts because they +// help avoid making incorrect conversions (e.g., when using +// `reinterpret_cast` or C-style casts), and because they +// abstract over the conversion required for each pair of types. +// + +SLANG_FORCE_INLINE slang::IGlobalSession* asExternal(Session* session) +{ + return static_cast<slang::IGlobalSession*>(session); +} + +SLANG_FORCE_INLINE ComPtr<Session> asInternal(slang::IGlobalSession* session) +{ + Slang::Session* internalSession = nullptr; + session->queryInterface(SLANG_IID_PPV_ARGS(&internalSession)); + return ComPtr<Session>(INIT_ATTACH, static_cast<Session*>(internalSession)); +} + +SLANG_FORCE_INLINE slang::ISession* asExternal(Linkage* linkage) +{ + return static_cast<slang::ISession*>(linkage); +} + +SLANG_FORCE_INLINE Module* asInternal(slang::IModule* module) +{ + return static_cast<Module*>(module); +} + +SLANG_FORCE_INLINE slang::IModule* asExternal(Module* module) +{ + return static_cast<slang::IModule*>(module); +} + +ComponentType* asInternal(slang::IComponentType* inComponentType); + +SLANG_FORCE_INLINE slang::IComponentType* asExternal(ComponentType* componentType) +{ + return static_cast<slang::IComponentType*>(componentType); +} + +SLANG_FORCE_INLINE slang::ProgramLayout* asExternal(ProgramLayout* programLayout) +{ + return (slang::ProgramLayout*)programLayout; +} + +SLANG_FORCE_INLINE Type* asInternal(slang::TypeReflection* type) +{ + return reinterpret_cast<Type*>(type); +} + +SLANG_FORCE_INLINE slang::TypeReflection* asExternal(Type* type) +{ + return reinterpret_cast<slang::TypeReflection*>(type); +} + +SLANG_FORCE_INLINE DeclRef<Decl> asInternal(slang::GenericReflection* generic) +{ + return DeclRef<Decl>(reinterpret_cast<DeclRefBase*>(generic)); +} + +SLANG_FORCE_INLINE slang::GenericReflection* asExternal(DeclRef<Decl> generic) +{ + return reinterpret_cast<slang::GenericReflection*>(generic.declRefBase); +} + +SLANG_FORCE_INLINE TypeLayout* asInternal(slang::TypeLayoutReflection* type) +{ + return reinterpret_cast<TypeLayout*>(type); +} + +SLANG_FORCE_INLINE slang::TypeLayoutReflection* asExternal(TypeLayout* type) +{ + return reinterpret_cast<slang::TypeLayoutReflection*>(type); +} + +SLANG_FORCE_INLINE SlangCompileRequest* asExternal(EndToEndCompileRequest* request) +{ + return static_cast<SlangCompileRequest*>(request); +} + +SLANG_FORCE_INLINE EndToEndCompileRequest* asInternal(SlangCompileRequest* request) +{ + // Converts to the internal type -- does a runtime type check through queryInterfae + SLANG_ASSERT(request); + EndToEndCompileRequest* endToEndRequest = nullptr; + // NOTE! We aren't using to access an interface, so *doesn't* return with a refcount + request->queryInterface(SLANG_IID_PPV_ARGS(&endToEndRequest)); + SLANG_ASSERT(endToEndRequest); + return endToEndRequest; +} + +SLANG_FORCE_INLINE SlangCompileTarget asExternal(CodeGenTarget target) +{ + return (SlangCompileTarget)target; +} + +SLANG_FORCE_INLINE SlangSourceLanguage asExternal(SourceLanguage sourceLanguage) +{ + return (SlangSourceLanguage)sourceLanguage; +} + +} // namespace Slang diff --git a/source/slang/slang-compiler-fwd.h b/source/slang/slang-compiler-fwd.h new file mode 100644 index 000000000..91f3e1236 --- /dev/null +++ b/source/slang/slang-compiler-fwd.h @@ -0,0 +1,31 @@ +// slang-compiler-fwd.h +#pragma once + +// +// This file provides forward declarations that are +// commonly used by the files that get included by +// as part of the umbrella header `slang-compiler.h`. +// + +namespace Slang +{ +class ASTBuilder; +class EndToEndCompileRequest; +class FrontEndCompileRequest; +struct IRModule; +class Linkage; +class Module; +struct ModuleChunk; +class ProgramLayout; +class Session; +class SharedASTBuilder; +struct SharedSemanticsContext; +class TargetProgram; +class TargetRequest; +class TranslationUnitRequest; +struct TypeCheckingCache; +class TypeLayout; + +using LoadedModule = Module; + +} // namespace Slang diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index b13624aac..4b689455b 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1,1223 +1,9 @@ -// Compiler.cpp : Defines the entry point for the console application. -// +// slang-compiler.cpp #include "slang-compiler.h" -#include "../compiler-core/slang-lexer.h" -#include "../core/slang-basic.h" -#include "../core/slang-castable.h" -#include "../core/slang-hex-dump-util.h" -#include "../core/slang-io.h" -#include "../core/slang-performance-profiler.h" -#include "../core/slang-platform.h" -#include "../core/slang-riff.h" -#include "../core/slang-string-util.h" -#include "../core/slang-type-convert-util.h" -#include "../core/slang-type-text-util.h" -#include "slang-check-impl.h" -#include "slang-check.h" - -#include <chrono> - -// Artifact -#include "../compiler-core/slang-artifact-associated.h" -#include "../compiler-core/slang-artifact-container-util.h" -#include "../compiler-core/slang-artifact-desc-util.h" -#include "../compiler-core/slang-artifact-diagnostic-util.h" -#include "../compiler-core/slang-artifact-impl.h" -#include "../compiler-core/slang-artifact-representation-impl.h" -#include "../compiler-core/slang-artifact-util.h" - -// Artifact output -#include "slang-artifact-output-util.h" -#include "slang-emit-cuda.h" -#include "slang-extension-tracker.h" -#include "slang-lower-to-ir.h" -#include "slang-mangle.h" -#include "slang-parameter-binding.h" -#include "slang-parser.h" -#include "slang-preprocessor.h" -#include "slang-serialize-ast.h" -#include "slang-serialize-container.h" -#include "slang-type-layout.h" - namespace Slang { -// !!!!!!!!!!!!!!!!!!!!!! free functions for DiagnosicSink !!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool isHeterogeneousTarget(CodeGenTarget target) -{ - return ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).style == - ArtifactStyle::Host; -} - -void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) -{ - UnownedStringSlice name = TypeTextUtil::getCompileTargetName(asExternal(val)); - name = name.getLength() ? name : toSlice("<unknown>"); - sb << name; -} - -void printDiagnosticArg(StringBuilder& sb, PassThroughMode val) -{ - sb << TypeTextUtil::getPassThroughName(SlangPassThrough(val)); -} - -// -// FrontEndEntryPointRequest -// - -FrontEndEntryPointRequest::FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile) - : m_compileRequest(compileRequest) - , m_translationUnitIndex(translationUnitIndex) - , m_name(name) - , m_profile(profile) -{ -} - - -TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() -{ - return getCompileRequest()->translationUnits[m_translationUnitIndex]; -} - -// -// EntryPoint -// - -ISlangUnknown* EntryPoint::getInterface(const Guid& guid) -{ - if (guid == slang::IEntryPoint::getTypeGuid()) - return static_cast<slang::IEntryPoint*>(this); - - return Super::getInterface(guid); -} - -RefPtr<EntryPoint> EntryPoint::create( - Linkage* linkage, - DeclRef<FuncDecl> funcDeclRef, - Profile profile) -{ - RefPtr<EntryPoint> entryPoint = - new EntryPoint(linkage, funcDeclRef.getName(), profile, funcDeclRef); - entryPoint->m_mangledName = getMangledName(linkage->getASTBuilder(), funcDeclRef); - return entryPoint; -} - -RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough( - Linkage* linkage, - Name* name, - Profile profile) -{ - RefPtr<EntryPoint> entryPoint = new EntryPoint(linkage, name, profile, DeclRef<FuncDecl>()); - return entryPoint; -} - -RefPtr<EntryPoint> EntryPoint::createDummyForDeserialize( - Linkage* linkage, - Name* name, - Profile profile, - String mangledName) -{ - RefPtr<EntryPoint> entryPoint = new EntryPoint(linkage, name, profile, DeclRef<FuncDecl>()); - entryPoint->m_mangledName = mangledName; - return entryPoint; -} - -EntryPoint::EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef) - : ComponentType(linkage), m_name(name), m_profile(profile), m_funcDeclRef(funcDeclRef) -{ - // Collect any specialization parameters used by the entry point - // - _collectShaderParams(); -} - -Module* EntryPoint::getModule() -{ - return Slang::getModule(getFuncDecl()); -} - -Index EntryPoint::getSpecializationParamCount() -{ - return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); -} - -SpecializationParam const& EntryPoint::getSpecializationParam(Index index) -{ - auto genericParamCount = m_genericSpecializationParams.getCount(); - if (index < genericParamCount) - { - return m_genericSpecializationParams[index]; - } - else - { - return m_existentialSpecializationParams[index - genericParamCount]; - } -} - -Index EntryPoint::getRequirementCount() -{ - // The only requirement of an entry point is the module that contains it. - // - // TODO: We will eventually want to support the case of an entry - // point nested in a `struct` type, in which case there should be - // a single requirement representing that outer type (so that multiple - // entry points nested under the same type can share the storage - // for parameters at that scope). - - // Note: the defensive coding is here because the - // "dummy" entry points we create for pass-through - // compilation will not have an associated module. - // - if (const auto module = getModule()) - { - return 1; - } - return 0; -} - -RefPtr<ComponentType> EntryPoint::getRequirement(Index index) -{ - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - SLANG_ASSERT(getModule()); - return getModule(); -} - -String EntryPoint::getEntryPointMangledName(Index index) -{ - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - - return m_mangledName; -} - -String EntryPoint::getEntryPointNameOverride(Index index) -{ - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - - return m_name ? m_name->text : ""; -} - -void EntryPoint::acceptVisitor( - ComponentTypeVisitor* visitor, - SpecializationInfo* specializationInfo) -{ - visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo)); -} - -void EntryPoint::buildHash(DigestBuilder<SHA1>& builder) -{ - SLANG_UNUSED(builder); -} - -List<Module*> const& EntryPoint::getModuleDependencies() -{ - if (auto module = getModule()) - return module->getModuleDependencies(); - - static List<Module*> empty; - return empty; -} - -List<SourceFile*> const& EntryPoint::getFileDependencies() -{ - if (const auto module = getModule()) - return getModule()->getFileDependencies(); - - static List<SourceFile*> empty; - return empty; -} - -TypeConformance::TypeConformance( - Linkage* linkage, - SubtypeWitness* witness, - Int confomrmanceIdOverride, - DiagnosticSink* sink) - : ComponentType(linkage) - , m_subtypeWitness(witness) - , m_conformanceIdOverride(confomrmanceIdOverride) -{ - addDepedencyFromWitness(witness); - m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); -} - -void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) -{ - if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) - { - auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); - m_moduleDependencyList.addDependency(declModule); - m_fileDependencyList.addDependency(declModule); - if (m_requirementSet.add(declModule)) - { - m_requirements.add(declModule); - } - // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. - } - else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) - { - addDepedencyFromWitness(transitiveWitness->getMidToSup()); - addDepedencyFromWitness(transitiveWitness->getSubToMid()); - } - else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) - { - auto componentCount = conjunctionWitness->getComponentCount(); - for (Index i = 0; i < componentCount; ++i) - { - auto w = as<SubtypeWitness>(conjunctionWitness->getComponentWitness(i)); - if (w) - addDepedencyFromWitness(w); - } - } -} - -ISlangUnknown* TypeConformance::getInterface(const Guid& guid) -{ - if (guid == slang::ITypeConformance::getTypeGuid()) - return static_cast<slang::ITypeConformance*>(this); - - return Super::getInterface(guid); -} - -void TypeConformance::buildHash(DigestBuilder<SHA1>& builder) -{ - // TODO: Implement some kind of hashInto for Val then replace this - auto subtypeWitness = m_subtypeWitness->toString(); - - builder.append(subtypeWitness); - builder.append(m_conformanceIdOverride); -} - -List<Module*> const& TypeConformance::getModuleDependencies() -{ - return m_moduleDependencyList.getModuleList(); -} - -List<SourceFile*> const& TypeConformance::getFileDependencies() -{ - return m_fileDependencyList.getFileList(); -} - -Index TypeConformance::getRequirementCount() -{ - return m_requirements.getCount(); -} - -RefPtr<ComponentType> TypeConformance::getRequirement(Index index) -{ - return m_requirements[index]; -} - -void TypeConformance::acceptVisitor( - ComponentTypeVisitor* visitor, - ComponentType::SpecializationInfo* specializationInfo) -{ - SLANG_UNUSED(specializationInfo); - visitor->visitTypeConformance(this); -} - -RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) -{ - SLANG_UNUSED(args); - SLANG_UNUSED(argCount); - SLANG_UNUSED(sink); - return nullptr; -} - -// - -Profile Profile::lookUp(UnownedStringSlice const& name) -{ -#define PROFILE(TAG, NAME, STAGE, VERSION) \ - if (name == UnownedTerminatedStringSlice(#NAME)) \ - return Profile::TAG; -#define PROFILE_ALIAS(TAG, DEF, NAME) \ - if (name == UnownedTerminatedStringSlice(#NAME)) \ - return Profile::TAG; -#include "slang-profile-defs.h" - - return Profile::Unknown; -} - -Profile Profile::lookUp(char const* name) -{ - return lookUp(UnownedTerminatedStringSlice(name)); -} - -CapabilitySet Profile::getCapabilityName() -{ - List<CapabilityName> result; - switch (getVersion()) - { -#define PROFILE_VERSION(TAG, NAME) \ - case ProfileVersion::TAG: \ - result.add(CapabilityName::TAG); \ - break; -#include "slang-profile-defs.h" - default: - break; - } - switch (getStage()) - { -#define PROFILE_STAGE(TAG, NAME, VAL) \ - case Stage::TAG: \ - result.add(CapabilityName::NAME); \ - break; -#include "slang-profile-defs.h" - default: - break; - } - - CapabilitySet resultSet = CapabilitySet(result); - for (auto i : this->additionalCapabilities) - resultSet.join(i); - return resultSet; -} - -char const* Profile::getName() -{ - switch (raw) - { - default: - return "unknown"; - -#define PROFILE(TAG, NAME, STAGE, VERSION) \ - case Profile::TAG: \ - return #NAME; -#define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ -#include "slang-profile-defs.h" - } -} - -static const StageInfo kStages[] = { -#define PROFILE_STAGE(ID, NAME, ENUM) {#NAME, Stage::ID}, - -#define PROFILE_STAGE_ALIAS(ID, NAME, VAL) {#NAME, Stage::ID}, - -#include "slang-profile-defs.h" -}; - -ConstArrayView<StageInfo> getStageInfos() -{ - return makeConstArrayView(kStages); -} - -Stage findStageByName(String const& name) -{ - for (auto entry : kStages) - { - if (name == entry.name) - { - return entry.stage; - } - } - - return Stage::Unknown; -} - -UnownedStringSlice getStageText(Stage stage) -{ - for (auto entry : kStages) - { - if (stage == entry.stage) - { - return UnownedStringSlice(entry.name); - } - } - return UnownedStringSlice(); -} - -Stage getStageFromAtom(CapabilityAtom atom) -{ - switch (atom) - { - case CapabilityAtom::vertex: - return Stage::Vertex; - case CapabilityAtom::hull: - return Stage::Hull; - case CapabilityAtom::domain: - return Stage::Domain; - case CapabilityAtom::geometry: - return Stage::Geometry; - case CapabilityAtom::fragment: - return Stage::Fragment; - case CapabilityAtom::compute: - return Stage::Compute; - case CapabilityAtom::_mesh: - return Stage::Mesh; - case CapabilityAtom::_amplification: - return Stage::Amplification; - case CapabilityAtom::_anyhit: - return Stage::AnyHit; - case CapabilityAtom::_closesthit: - return Stage::ClosestHit; - case CapabilityAtom::_intersection: - return Stage::Intersection; - case CapabilityAtom::_raygen: - return Stage::RayGeneration; - case CapabilityAtom::_miss: - return Stage::Miss; - case CapabilityAtom::_callable: - return Stage::Callable; - case CapabilityAtom::dispatch: - return Stage::Dispatch; - default: - SLANG_UNEXPECTED("unknown stage atom"); - UNREACHABLE_RETURN(Stage::Unknown); - } -} - -CapabilityAtom getAtomFromStage(Stage stage) -{ - // Convert Slang::Stage to CapabilityAtom. - // Note that capabilities do not share the same values as Slang::Stage - // and must be explicitly converted. - switch (stage) - { - case Stage::Compute: - return CapabilityAtom::compute; - case Stage::Vertex: - return CapabilityAtom::vertex; - case Stage::Fragment: - return CapabilityAtom::fragment; - case Stage::Geometry: - return CapabilityAtom::geometry; - case Stage::Hull: - return CapabilityAtom::hull; - case Stage::Domain: - return CapabilityAtom::domain; - case Stage::Mesh: - return CapabilityAtom::_mesh; - case Stage::Amplification: - return CapabilityAtom::_amplification; - case Stage::RayGeneration: - return CapabilityAtom::_raygen; - case Stage::AnyHit: - return CapabilityAtom::_anyhit; - case Stage::ClosestHit: - return CapabilityAtom::_closesthit; - case Stage::Miss: - return CapabilityAtom::_miss; - case Stage::Intersection: - return CapabilityAtom::_intersection; - case Stage::Callable: - return CapabilityAtom::_callable; - case Stage::Dispatch: - return CapabilityAtom::dispatch; - default: - SLANG_UNEXPECTED("unknown stage"); - UNREACHABLE_RETURN(CapabilityAtom::Invalid); - } -} - -SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) -{ - // Check if the type is supported on this compile - if (passThrough == PassThroughMode::None) - { - // If no pass through -> that will always work! - return SLANG_OK; - } - - return session->getOrLoadDownstreamCompiler(passThrough, nullptr) ? SLANG_OK - : SLANG_E_NOT_FOUND; -} - -SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler) -{ - switch (compiler) - { - case PassThroughMode::None: - { - return SourceLanguage::Unknown; - } - case PassThroughMode::Fxc: - case PassThroughMode::Dxc: - { - return SourceLanguage::HLSL; - } - case PassThroughMode::Glslang: - { - return SourceLanguage::GLSL; - } - case PassThroughMode::LLVM: - case PassThroughMode::Clang: - case PassThroughMode::VisualStudio: - case PassThroughMode::Gcc: - case PassThroughMode::GenericCCpp: - { - // These could ingest C, but we only have this function to work out a - // 'default' language to ingest. - return SourceLanguage::CPP; - } - case PassThroughMode::NVRTC: - { - return SourceLanguage::CUDA; - } - case PassThroughMode::Tint: - { - return SourceLanguage::WGSL; - } - case PassThroughMode::SpirvDis: - { - return SourceLanguage::SPIRV; - } - case PassThroughMode::MetalC: - { - return SourceLanguage::Metal; - } - default: - break; - } - SLANG_ASSERT(!"Unknown compiler"); - return SourceLanguage::Unknown; -} - -PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target) -{ - switch (target) - { - // Don't *require* a downstream compiler for source output - case CodeGenTarget::GLSL: - case CodeGenTarget::HLSL: - case CodeGenTarget::CUDASource: - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::PyTorchCppBinding: - case CodeGenTarget::CSource: - case CodeGenTarget::Metal: - case CodeGenTarget::WGSL: - { - return PassThroughMode::None; - } - case CodeGenTarget::None: - { - return PassThroughMode::None; - } - case CodeGenTarget::WGSLSPIRVAssembly: - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::SPIRV: - { - return PassThroughMode::SpirvDis; - } - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - { - return PassThroughMode::Fxc; - } - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - { - return PassThroughMode::Dxc; - } - case CodeGenTarget::MetalLib: - case CodeGenTarget::MetalLibAssembly: - { - return PassThroughMode::MetalC; - } - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostSharedLibrary: - { - // We need some C/C++ compiler - return PassThroughMode::GenericCCpp; - } - case CodeGenTarget::PTX: - { - return PassThroughMode::NVRTC; - } - case CodeGenTarget::WGSLSPIRV: - { - return PassThroughMode::Tint; - } - default: - break; - } - - SLANG_ASSERT(!"Unhandled target"); - return PassThroughMode::None; -} - -EndToEndCompileRequest* CodeGenContext::isPassThroughEnabled() -{ - auto endToEndReq = isEndToEndCompile(); - - // If there isn't an end-to-end compile going on, - // there can be no pass-through. - // - if (!endToEndReq) - return nullptr; - - // And if pass-through isn't set on that end-to-end compile, - // then we clearly areb't doing a pass-through compile. - // - if (endToEndReq->m_passThrough == PassThroughMode::None) - return nullptr; - - // If we have confirmed that pass-through compilation is going on, - // we return the end-to-end request, because it has all the - // relevant state that we need to implement pass-through mode. - // - return endToEndReq; -} - -/// If there is a pass-through compile going on, find the translation unit for the given entry -/// point. Assumes isPassThroughEnabled has already been called -TranslationUnitRequest* getPassThroughTranslationUnit( - EndToEndCompileRequest* endToEndReq, - Int entryPointIndex) -{ - SLANG_ASSERT(endToEndReq); - SLANG_ASSERT(endToEndReq->m_passThrough != PassThroughMode::None); - auto frontEndReq = endToEndReq->getFrontEndReq(); - auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); - auto translationUnit = entryPointReq->getTranslationUnit(); - return translationUnit; -} - -TranslationUnitRequest* CodeGenContext::findPassThroughTranslationUnit(Int entryPointIndex) -{ - if (auto endToEndReq = isPassThroughEnabled()) - return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - return nullptr; -} - -static void _appendCodeWithPath( - const UnownedStringSlice& filePath, - const UnownedStringSlice& fileContent, - StringBuilder& outCodeBuilder) -{ - outCodeBuilder << "#line 1 \""; - auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp); - handler->appendEscaped(filePath, outCodeBuilder); - outCodeBuilder << "\"\n"; - outCodeBuilder << fileContent << "\n"; -} - -void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps) -{ - for (auto& conjunctions : caps.getAtomSets()) - { - for (auto atom : conjunctions) - { - switch (asAtom(atom)) - { - default: - break; - - case CapabilityAtom::glsl_spirv_1_0: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 0)); - break; - case CapabilityAtom::glsl_spirv_1_1: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 1)); - break; - case CapabilityAtom::glsl_spirv_1_2: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 2)); - break; - case CapabilityAtom::glsl_spirv_1_3: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 3)); - break; - case CapabilityAtom::glsl_spirv_1_4: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); - break; - case CapabilityAtom::glsl_spirv_1_5: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 5)); - break; - case CapabilityAtom::glsl_spirv_1_6: - extensionTracker->requireSPIRVVersion(SemanticVersion(1, 6)); - break; - } - } - } -} - -SlangResult CodeGenContext::requireTranslationUnitSourceFiles() -{ - if (auto endToEndReq = isPassThroughEnabled()) - { - for (auto entryPointIndex : getEntryPointIndices()) - { - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - SLANG_ASSERT(translationUnit); - /// Make sure we have the source files - SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); - } - } - - return SLANG_OK; -} - -#if SLANG_VC -// TODO(JS): This is a workaround -// In debug VS builds there is a warning on line about it being unreachable. -// for (auto entryPointIndex : getEntryPointIndices()) -// It's not clear how that could possibly be unreachable -#pragma warning(push) -#pragma warning(disable : 4702) -#endif -SlangResult CodeGenContext::emitEntryPointsSource(ComPtr<IArtifact>& outArtifact) -{ - outArtifact.setNull(); - - SLANG_RETURN_ON_FAIL(requireTranslationUnitSourceFiles()); - - auto endToEndReq = isPassThroughEnabled(); - if (endToEndReq) - { - for (auto entryPointIndex : getEntryPointIndices()) - { - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - SLANG_ASSERT(translationUnit); - - /// Make sure we have the source files - SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); - - // Generate a string that includes the content of - // the source file(s), along with a line directive - // to ensure that we get reasonable messages - // from the downstream compiler when in pass-through - // mode. - - StringBuilder codeBuilder; - if (getTargetFormat() == CodeGenTarget::GLSL) - { - // Special case GLSL - int translationUnitCounter = 0; - for (auto sourceFile : translationUnit->getSourceFiles()) - { - int translationUnitIndex = translationUnitCounter++; - - // We want to output `#line` directives, but we need - // to skip this for the first file, since otherwise - // some GLSL implementations will get tripped up by - // not having the `#version` directive be the first - // thing in the file. - if (translationUnitIndex != 0) - { - codeBuilder << "#line 1 " << translationUnitIndex << "\n"; - } - codeBuilder << sourceFile->getContent() << "\n"; - } - } - else - { - for (auto sourceFile : translationUnit->getSourceFiles()) - { - _appendCodeWithPath( - sourceFile->getPathInfo().foundPath.getUnownedSlice(), - sourceFile->getContent(), - codeBuilder); - } - } - - auto artifact = - ArtifactUtil::createArtifactForCompileTarget(asExternal(getTargetFormat())); - artifact->addRepresentationUnknown(StringBlob::moveCreate(codeBuilder)); - - outArtifact.swap(artifact); - return SLANG_OK; - } - return SLANG_OK; - } - else - { - return emitEntryPointsSourceFromIR(outArtifact); - } -} -#if SLANG_VC -#pragma warning(pop) -#endif - -SlangResult CodeGenContext::emitPrecompiledDownstreamIR(ComPtr<IArtifact>& outArtifact) -{ - return _emitEntryPoints(outArtifact); -} - -String GetHLSLProfileName(Profile profile) -{ - switch (profile.getFamily()) - { - case ProfileFamily::DX: - // Profile version is a DX one, so stick with it. - break; - - default: - // Profile is a non-DX profile family, so we need to try - // to clobber it with something to get a default. - // - // TODO: This is a huge hack... - profile.setVersion(ProfileVersion::DX_5_1); - break; - } - - char const* stagePrefix = nullptr; - switch (profile.getStage()) - { - // Note: All of the raytracing-related stages require - // compiling for a `lib_*` profile, even when only a - // single entry point is present. - // - // We also go ahead and use this target in any case - // where we don't know the actual stage to compiel for, - // as a fallback option. - // - // TODO: We also want to use this option when compiling - // multiple entry points to a DXIL library. - // - default: - stagePrefix = "lib"; - break; - - // The traditional rasterization pipeline and compute - // shaders all have custom profile names that identify - // both the stage and shader model, which need to be - // used when compiling a single entry point. - // -#define CASE(NAME, PREFIX) \ - case Stage::NAME: \ - stagePrefix = #PREFIX; \ - break - CASE(Vertex, vs); - CASE(Hull, hs); - CASE(Domain, ds); - CASE(Geometry, gs); - CASE(Fragment, ps); - CASE(Compute, cs); - CASE(Amplification, as); - CASE(Mesh, ms); -#undef CASE - } - - char const* versionSuffix = nullptr; - switch (profile.getVersion()) - { -#define CASE(TAG, SUFFIX) \ - case ProfileVersion::TAG: \ - versionSuffix = #SUFFIX; \ - break - CASE(DX_4_0, _4_0); - CASE(DX_4_1, _4_1); - CASE(DX_5_0, _5_0); - CASE(DX_5_1, _5_1); - CASE(DX_6_0, _6_0); - CASE(DX_6_1, _6_1); - CASE(DX_6_2, _6_2); - CASE(DX_6_3, _6_3); - CASE(DX_6_4, _6_4); - CASE(DX_6_5, _6_5); - CASE(DX_6_6, _6_6); - CASE(DX_6_7, _6_7); - CASE(DX_6_8, _6_8); - CASE(DX_6_9, _6_9); -#undef CASE - - default: - return "unknown"; - } - - String result; - result.append(stagePrefix); - result.append(versionSuffix); - return result; -} - -void reportExternalCompileError( - const char* compilerName, - Severity severity, - SlangResult res, - const UnownedStringSlice& diagnostic, - DiagnosticSink* sink) -{ - StringBuilder builder; - if (compilerName) - { - builder << compilerName << ": "; - } - - if (SLANG_FAILED(res) && res != SLANG_FAIL) - { - { - char tmp[17]; - sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); - builder << "Result(" << tmp << ") "; - } - - PlatformUtil::appendResult(res, builder); - } - - if (diagnostic.getLength() > 0) - { - builder.append(diagnostic); - if (!diagnostic.endsWith("\n")) - { - builder.append("\n"); - } - } - - sink->diagnoseRaw(severity, builder.getUnownedSlice()); -} - -void reportExternalCompileError( - const char* compilerName, - SlangResult res, - const UnownedStringSlice& diagnostic, - DiagnosticSink* sink) -{ - // TODO(tfoley): need a better policy for how we translate diagnostics - // back into the Slang world (although we should always try to generate - // HLSL that doesn't produce any diagnostics...) - reportExternalCompileError( - compilerName, - SLANG_FAILED(res) ? Severity::Error : Severity::Warning, - res, - diagnostic, - sink); -} - -static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) -{ - if (sink->isFlagSet(DiagnosticSink::Flag::VerbosePath)) - { - return sourceFile->calcVerbosePath(); - } - else - { - return sourceFile->getPathInfo().foundPath; - } -} - -String CodeGenContext::calcSourcePathForEntryPoints() -{ - String failureMode = "slang-generated"; - if (getEntryPointCount() != 1) - return failureMode; - auto entryPointIndex = getSingleEntryPointIndex(); - auto translationUnitRequest = findPassThroughTranslationUnit(entryPointIndex); - if (!translationUnitRequest) - return failureMode; - - const auto& sourceFiles = translationUnitRequest->getSourceFiles(); - - auto sink = getSink(); - - const Index numSourceFiles = sourceFiles.getCount(); - - switch (numSourceFiles) - { - case 0: - return "unknown"; - case 1: - return _getDisplayPath(sink, sourceFiles[0]); - default: - { - StringBuilder builder; - builder << _getDisplayPath(sink, sourceFiles[0]); - for (int i = 1; i < numSourceFiles; ++i) - { - builder << ";" << _getDisplayPath(sink, sourceFiles[i]); - } - return builder; - } - } -} - -// Helper function for cases where we can assume a single entry point -Int assertSingleEntryPoint(List<Int> const& entryPointIndices) -{ - SLANG_ASSERT(entryPointIndices.getCount() == 1); - return *entryPointIndices.begin(); -} - -// True if it's best to use 'emitted' source for complication. For a downstream compiler -// that is not file based, this is always ok. -/// -/// If the downstream compiler is file system based, we may want to just use the file that was -/// passed to be compiled. That the downstream compiler can determine if it will then save the file -/// or not based on if it's a match - and generally there will not be a match with emitted source. -/// -/// This test is only used for pass through mode. -static bool _useEmittedSource( - IDownstreamCompiler* compiler, - TranslationUnitRequest* translationUnit) -{ - // We only bother if it's a file based compiler. - if (compiler->isFileBased()) - { - // It can only have *one* source file as otherwise we have to combine to make a new source - // file anyway - return translationUnit->getSourceArtifacts().getCount() != 1; - } - return true; -} - -static Severity _getDiagnosticSeverity(ArtifactDiagnostic::Severity severity) -{ - switch (severity) - { - case ArtifactDiagnostic::Severity::Warning: - return Severity::Warning; - case ArtifactDiagnostic::Severity::Info: - return Severity::Note; - default: - return Severity::Error; - } -} - -static RefPtr<ExtensionTracker> _newExtensionTracker(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::PTX: - case CodeGenTarget::CUDASource: - { - return new CUDAExtensionTracker; - } - case CodeGenTarget::SPIRV: - case CodeGenTarget::GLSL: - case CodeGenTarget::WGSL: - case CodeGenTarget::WGSLSPIRV: - case CodeGenTarget::WGSLSPIRVAssembly: - { - return new ShaderExtensionTracker; - } - default: - return nullptr; - } -} - -static CodeGenTarget _getDefaultSourceForTarget(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - { - return CodeGenTarget::CPPSource; - } - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostSharedLibrary: - { - return CodeGenTarget::HostCPPSource; - } - case CodeGenTarget::PTX: - return CodeGenTarget::CUDASource; - case CodeGenTarget::DXBytecode: - return CodeGenTarget::HLSL; - case CodeGenTarget::DXIL: - return CodeGenTarget::HLSL; - case CodeGenTarget::SPIRV: - return CodeGenTarget::GLSL; - case CodeGenTarget::MetalLib: - return CodeGenTarget::Metal; - case CodeGenTarget::WGSLSPIRV: - return CodeGenTarget::WGSL; - default: - break; - } - return CodeGenTarget::Unknown; -} - -static bool _isCPUHostTarget(CodeGenTarget target) -{ - auto desc = ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)); - return desc.style == ArtifactStyle::Host; -} - -static bool _shouldSetEntryPointName(TargetProgram* targetProgram) -{ - if (!isKhronosTarget(targetProgram->getTargetReq())) - return true; - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::VulkanUseEntryPointName)) - return true; - return false; -} - -SlangResult passthroughDownstreamDiagnostics( - DiagnosticSink* sink, - IDownstreamCompiler* compiler, - IArtifact* artifact) -{ - auto diagnostics = findAssociatedRepresentation<IArtifactDiagnostics>(artifact); - - if (!diagnostics) - return SLANG_OK; - - if (diagnostics->getCount()) - { - StringBuilder compilerText; - DownstreamCompilerUtil::appendAsText(compiler->getDesc(), compilerText); - - StringBuilder builder; - - auto const diagnosticCount = diagnostics->getCount(); - for (Index i = 0; i < diagnosticCount; ++i) - { - const auto& diagnostic = *diagnostics->getAt(i); - - builder.clear(); - - const Severity severity = _getDiagnosticSeverity(diagnostic.severity); - - if (diagnostic.filePath.count == 0 && diagnostic.location.line == 0 && - severity == Severity::Note) - { - // If theres no filePath line number and it's info, output severity and text alone - builder << getSeverityName(severity) << " : "; - } - else - { - if (diagnostic.filePath.count) - { - builder << asStringSlice(diagnostic.filePath); - } - - if (diagnostic.location.line) - { - builder << "(" << diagnostic.location.line << ")"; - } - - builder << ": "; - - if (diagnostic.stage == ArtifactDiagnostic::Stage::Link) - { - builder << "link "; - } - - builder << getSeverityName(severity); - builder << " " << asStringSlice(diagnostic.code) << ": "; - } - - builder << asStringSlice(diagnostic.text); - reportExternalCompileError( - compilerText.getBuffer(), - severity, - SLANG_OK, - builder.getUnownedSlice(), - sink); - } - } - - // If any errors are emitted, then we are done - if (diagnostics->hasOfAtLeastSeverity(ArtifactDiagnostic::Severity::Error)) - { - return SLANG_FAIL; - } - - return SLANG_OK; -} - bool isValidSlangLanguageVersion(SlangLanguageVersion version) { switch (version) @@ -1258,1885 +44,4 @@ bool isValidGLSLVersion(int version) } } -SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& outArtifact) -{ - outArtifact.setNull(); - - auto sink = getSink(); - auto session = getSession(); - - CodeGenTarget sourceTarget = CodeGenTarget::None; - SourceLanguage sourceLanguage = SourceLanguage::Unknown; - - auto target = getTargetFormat(); - RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); - PassThroughMode compilerType; - - SliceAllocator allocator; - - if (auto endToEndReq = isPassThroughEnabled()) - { - compilerType = endToEndReq->m_passThrough; - } - else - { - // If we are not in pass through, lookup the default compiler for the emitted source type - - // Get the default source codegen type for a given target - sourceTarget = _getDefaultSourceForTarget(target); - compilerType = (PassThroughMode)session->getDownstreamCompilerForTransition( - (SlangCompileTarget)sourceTarget, - (SlangCompileTarget)target); - // We should have a downstream compiler set at this point - if (compilerType == PassThroughMode::None) - { - auto sourceName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(sourceTarget)); - auto targetName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(target)); - - sink->diagnose( - SourceLoc(), - Diagnostics::compilerNotDefinedForTransition, - sourceName, - targetName); - return SLANG_FAIL; - } - } - - SLANG_ASSERT(compilerType != PassThroughMode::None); - - // Get the required downstream compiler - IDownstreamCompiler* compiler = session->getOrLoadDownstreamCompiler(compilerType, sink); - if (!compiler) - { - auto compilerName = TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)compilerType); - sink->diagnose(SourceLoc(), Diagnostics::passThroughCompilerNotFound, compilerName); - return SLANG_FAIL; - } - - Dictionary<String, String> preprocessorDefinitions; - List<String> includePaths; - - typedef DownstreamCompileOptions CompileOptions; - CompileOptions options; - - List<DownstreamCompileOptions::CapabilityVersion> requiredCapabilityVersions; - List<String> compilerSpecificArguments; - List<ComPtr<IArtifact>> libraries; - List<String> libraryPaths; - - // Set compiler specific args - { - auto name = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); - List<String> downstreamArgs = getTargetProgram()->getOptionSet().getDownstreamArgs(name); - for (const auto& arg : downstreamArgs) - { - // We special case some kinds of args, that can be handled directly - if (arg.startsWith("-I")) - { - // We handle the -I option, by just adding to the include paths - includePaths.add(arg.getUnownedSlice().tail(2)); - } - else - { - compilerSpecificArguments.add(arg); - } - } - } - - ComPtr<IArtifact> sourceArtifact; - - /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we - would ideally like to use the original file. We want to do this because we want includes - relative to the source file to work, and for that to work most easily we want to use the - original file, if there is one */ - if (auto endToEndReq = isPassThroughEnabled()) - { - // If we are pass through, we may need to set extension tracker state. - if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) - { - trackGLSLTargetCaps(glslTracker, getTargetCaps()); - } - - auto translationUnit = - getPassThroughTranslationUnit(endToEndReq, getSingleEntryPointIndex()); - - // We are just passing thru, so it's whatever it originally was - sourceLanguage = translationUnit->sourceLanguage; - - // TODO(JS): This seems like a bit of a hack - // That if a pass-through is being performed and the source language is Slang - // no downstream compiler knows how to deal with that, so probably means 'HLSL' - sourceLanguage = - (sourceLanguage == SourceLanguage::Slang) ? SourceLanguage::HLSL : sourceLanguage; - sourceTarget = CodeGenTarget(TypeConvertUtil::getCompileTargetFromSourceLanguage( - (SlangSourceLanguage)sourceLanguage)); - - // If it's pass through we accumulate the preprocessor definitions. - for (const auto& define : - endToEndReq->getOptionSet().getArray(CompilerOptionName::MacroDefine)) - preprocessorDefinitions.add(define.stringValue, define.stringValue2); - for (const auto& define : translationUnit->preprocessorDefinitions) - preprocessorDefinitions.add(define); - - { - /* TODO(JS): Not totally clear what options should be set here. If we are using the pass - through - then using say the defines/includes all makes total sense. If we are - generating C++ code from slang, then should we really be using these values -> aren't - they what is being set for the *slang* source, not for the C++ generated code. That - being the case it implies that there needs to be a mechanism (if there isn't already) to - specify such information on a particular pass/pass through etc. - - On invoking DXC for example include paths do not appear to be set at all (even with - pass-through). - */ - - auto linkage = getLinkage(); - - // Add all the search paths - - const auto searchDirectories = linkage->getSearchDirectories(); - const SearchDirectoryList* searchList = &searchDirectories; - while (searchList) - { - for (const auto& searchDirectory : searchList->searchDirectories) - { - includePaths.add(searchDirectory.path); - } - searchList = searchList->parent; - } - } - - // If emitted source is required, emit and set the path - if (_useEmittedSource(compiler, translationUnit)) - { - CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); - - SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); - - // If it's not file based we can set an appropriate path name, and it doesn't matter if - // it doesn't exist on the file system. We set the name to the path as this will be used - // for downstream reporting. - auto sourcePath = calcSourcePathForEntryPoints(); - sourceArtifact->setName(sourcePath.getBuffer()); - - sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); - } - else - { - // Special case if we have a single file, so that we pass the path, and the contents as - // is. - const auto& sourceArtifacts = translationUnit->getSourceArtifacts(); - SLANG_ASSERT(sourceArtifacts.getCount() == 1); - - sourceArtifact = sourceArtifacts[0]; - SLANG_ASSERT(sourceArtifact); - } - } - else - { - CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); - - sourceCodeGenContext.removeAvailableInDownstreamIR = true; - - SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); - sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); - - sourceLanguage = (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget( - (SlangCompileTarget)sourceTarget); - } - - if (sourceArtifact) - { - // Set the source artifacts - options.sourceArtifacts = makeSlice(sourceArtifact.readRef(), 1); - } - - // Add any preprocessor definitions associated with the linkage - { - // TODO(JS): This is somewhat arguable - should defines passed to Slang really be - // passed to downstream compilers? It does appear consistent with the behavior if - // there is an endToEndReq. - // - // That said it's very convenient and provides way to control aspects - // of downstream compilation. - - for (const auto& define : - getTargetProgram()->getOptionSet().getArray(CompilerOptionName::MacroDefine)) - { - preprocessorDefinitions.addIfNotExists(define.stringValue, define.stringValue2); - } - } - - - // If we have an extension tracker, we may need to set options such as SPIR-V version - // and CUDA Shader Model. - if (extensionTracker) - { - // Look for the version - if (auto cudaTracker = as<CUDAExtensionTracker>(extensionTracker)) - { - cudaTracker->finalize(); - - if (cudaTracker->m_smVersion.isSet()) - { - DownstreamCompileOptions::CapabilityVersion version; - version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::CUDASM; - version.version = cudaTracker->m_smVersion; - - requiredCapabilityVersions.add(version); - } - - if (cudaTracker->isBaseTypeRequired(BaseType::Half)) - { - options.flags |= CompileOptions::Flag::EnableFloat16; - } - } - else if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) - { - DownstreamCompileOptions::CapabilityVersion version; - version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; - version.version = glslTracker->getSPIRVVersion(); - - requiredCapabilityVersions.add(version); - } - } - - CapabilitySet targetCaps = getTargetCaps(); - for (auto atomSets : targetCaps.getAtomSets()) - { - for (auto atomVal : atomSets) - { - auto atom = CapabilityAtom(atomVal); - switch (atom) - { - default: - break; - -#define CASE(KIND, NAME, VERSION) \ - case CapabilityAtom::NAME: \ - requiredCapabilityVersions.add(DownstreamCompileOptions::CapabilityVersion{ \ - DownstreamCompileOptions::CapabilityVersion::Kind::KIND, \ - VERSION}); \ - break - - CASE(CUDASM, _cuda_sm_1_0, SemanticVersion(1, 0)); - CASE(CUDASM, _cuda_sm_2_0, SemanticVersion(2, 0)); - CASE(CUDASM, _cuda_sm_3_0, SemanticVersion(3, 0)); - CASE(CUDASM, _cuda_sm_4_0, SemanticVersion(4, 0)); - CASE(CUDASM, _cuda_sm_5_0, SemanticVersion(5, 0)); - CASE(CUDASM, _cuda_sm_6_0, SemanticVersion(6, 0)); - CASE(CUDASM, _cuda_sm_7_0, SemanticVersion(7, 0)); - CASE(CUDASM, _cuda_sm_8_0, SemanticVersion(8, 0)); - CASE(CUDASM, _cuda_sm_9_0, SemanticVersion(9, 0)); - -#undef CASE - } - } - } - - // Set the file sytem and source manager, as *may* be used by downstream compiler - options.fileSystemExt = getFileSystemExt(); - options.sourceManager = getSourceManager(); - - // Set the source type - options.sourceLanguage = SlangSourceLanguage(sourceLanguage); - - switch (target) - { - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - // Disable exceptions and security checks - options.flags &= - ~(CompileOptions::Flag::EnableExceptionHandling | - CompileOptions::Flag::EnableSecurityChecks); - break; - } - - Profile profile; - - if (compilerType == PassThroughMode::Fxc || compilerType == PassThroughMode::Dxc || - compilerType == PassThroughMode::Glslang) - { - const auto entryPointIndices = getEntryPointIndices(); - auto targetReq = getTargetReq(); - - const auto entryPointIndicesCount = entryPointIndices.getCount(); - - // Whole program means - // * can have 0-N entry points - // * 'doesn't build into an executable/kernel' - // - // So in some sense it is a library - if (getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::GenerateWholeProgram)) - { - if (compilerType == PassThroughMode::Dxc) - { - // Can support no entry points on DXC because we can build libraries - profile = - Profile(getTargetProgram()->getOptionSet().getEnumOption<Profile::RawEnum>( - CompilerOptionName::Profile)); - } - else - { - auto downstreamCompilerName = - TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); - - sink->diagnose( - SourceLoc(), - Diagnostics::downstreamCompilerDoesntSupportWholeProgramCompilation, - downstreamCompilerName); - return SLANG_FAIL; - } - } - else if (entryPointIndicesCount == 1) - { - // All support a single entry point - const Index entryPointIndex = entryPointIndices[0]; - - auto entryPoint = getEntryPoint(entryPointIndex); - profile = getEffectiveProfile(entryPoint, targetReq); - - if (_shouldSetEntryPointName(getTargetProgram())) - { - options.entryPointName = allocator.allocate(getText(entryPoint->getName())); - auto entryPointNameOverride = - getProgram()->getEntryPointNameOverride(entryPointIndex); - if (entryPointNameOverride.getLength() != 0) - { - options.entryPointName = allocator.allocate(entryPointNameOverride); - } - } - } - else - { - // We only support a single entry point on this target - SLANG_ASSERT(!"Can only compile with a single entry point on this target"); - return SLANG_FAIL; - } - - options.stage = SlangStage(profile.getStage()); - - if (compilerType == PassThroughMode::Dxc) - { - // We will enable the flag to generate proper code for 16 - bit types - // by default, as long as the user is requesting a sufficiently - // high shader model. - // - // TODO: Need to check that this is safe to enable in all cases, - // or if it will make a shader demand hardware features that - // aren't always present. - // - // TODO: Ideally the dxc back-end should be passed some information - // on the "capabilities" that were used and/or requested in the code. - // - if (profile.getVersion() >= ProfileVersion::DX_6_2) - { - options.flags |= CompileOptions::Flag::EnableFloat16; - } - - // Set the matrix layout - options.matrixLayout = - (SlangMatrixLayoutMode)getTargetProgram()->getOptionSet().getMatrixLayoutMode(); - } - - // Set the profile - options.profileName = allocator.allocate(GetHLSLProfileName(profile)); - } - - // If we aren't using LLVM 'host callable', we want downstream compile to produce a shared - // library - if (compilerType != PassThroughMode::LLVM && - ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).kind == - ArtifactKind::HostCallable) - { - target = CodeGenTarget::ShaderSharedLibrary; - } - - if (!isPassThroughEnabled()) - { - if (_isCPUHostTarget(target)) - { - libraryPaths.add(Path::getParentDirectory(Path::getExecutablePath())); - libraryPaths.add( - Path::combine(Path::getParentDirectory(Path::getExecutablePath()), "../lib")); - - // Set up the library artifact - auto artifact = Artifact::create( - ArtifactDesc::make(ArtifactKind::Library, Artifact::Payload::HostCPU), - toSlice("slang-rt")); - - ComPtr<IOSFileArtifactRepresentation> fileRep(new OSFileArtifactRepresentation( - IOSFileArtifactRepresentation::Kind::NameOnly, - toSlice("slang-rt"), - nullptr)); - artifact->addRepresentation(fileRep); - - libraries.add(artifact); - } - } - - options.targetType = (SlangCompileTarget)target; - - // Need to configure for the compilation - - { - auto linkage = getLinkage(); - - switch (getTargetProgram()->getOptionSet().getEnumOption<OptimizationLevel>( - CompilerOptionName::Optimization)) - { - case OptimizationLevel::None: - options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::None; - break; - case OptimizationLevel::Default: - options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Default; - break; - case OptimizationLevel::High: - options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::High; - break; - case OptimizationLevel::Maximal: - options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Maximal; - break; - default: - SLANG_ASSERT(!"Unhandled optimization level"); - break; - } - - switch (getTargetProgram()->getOptionSet().getEnumOption<DebugInfoLevel>( - CompilerOptionName::DebugInformation)) - { - case DebugInfoLevel::None: - options.debugInfoType = DownstreamCompileOptions::DebugInfoType::None; - break; - case DebugInfoLevel::Minimal: - options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Minimal; - break; - - case DebugInfoLevel::Standard: - options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Standard; - break; - case DebugInfoLevel::Maximal: - options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Maximal; - break; - default: - SLANG_ASSERT(!"Unhandled debug level"); - break; - } - - switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointMode>( - CompilerOptionName::FloatingPointMode)) - { - case FloatingPointMode::Default: - options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Default; - break; - case FloatingPointMode::Precise: - options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Precise; - break; - case FloatingPointMode::Fast: - options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Fast; - break; - default: - SLANG_ASSERT(!"Unhandled floating point mode"); - } - - if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp16)) - { - switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( - CompilerOptionName::DenormalModeFp16)) - { - case FloatingPointDenormalMode::Any: - options.denormalModeFp16 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; - break; - case FloatingPointDenormalMode::Preserve: - options.denormalModeFp16 = - DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; - break; - case FloatingPointDenormalMode::FlushToZero: - options.denormalModeFp16 = - DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; - break; - default: - SLANG_ASSERT(!"Unhandled fp16 denormal handling mode"); - } - } - - if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp32)) - { - switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( - CompilerOptionName::DenormalModeFp32)) - { - case FloatingPointDenormalMode::Any: - options.denormalModeFp32 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; - break; - case FloatingPointDenormalMode::Preserve: - options.denormalModeFp32 = - DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; - break; - case FloatingPointDenormalMode::FlushToZero: - options.denormalModeFp32 = - DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; - break; - default: - SLANG_ASSERT(!"Unhandled fp32 denormal handling mode"); - } - } - - if (getTargetProgram()->getOptionSet().hasOption(CompilerOptionName::DenormalModeFp64)) - { - switch (getTargetProgram()->getOptionSet().getEnumOption<FloatingPointDenormalMode>( - CompilerOptionName::DenormalModeFp64)) - { - case FloatingPointDenormalMode::Any: - options.denormalModeFp64 = DownstreamCompileOptions::FloatingPointDenormalMode::Any; - break; - case FloatingPointDenormalMode::Preserve: - options.denormalModeFp64 = - DownstreamCompileOptions::FloatingPointDenormalMode::Preserve; - break; - case FloatingPointDenormalMode::FlushToZero: - options.denormalModeFp64 = - DownstreamCompileOptions::FloatingPointDenormalMode::FlushToZero; - break; - default: - SLANG_ASSERT(!"Unhandled fp64 denormal handling mode"); - } - } - - { - // We need to look at the stage of the entry point(s) we are - // being asked to compile, since this will determine the - // "pipeline" that the result should be compiled for (e.g., - // compute vs. ray tracing). - // - // TODO: This logic is kind of messy in that it assumes - // a program to be compiled will only contain kernels for - // a single pipeline type, but that invariant isn't expressed - // at all in the front-end today. It also has no error - // checking for the case where there are conflicts. - // - // HACK: Right now none of the above concerns matter - // because we always perform code generation on a single - // entry point at a time. - // - Index entryPointCount = getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - auto stage = getEntryPoint(ee)->getStage(); - switch (stage) - { - default: - break; - - case Stage::Compute: - options.pipelineType = DownstreamCompileOptions::PipelineType::Compute; - break; - - case Stage::Vertex: - case Stage::Hull: - case Stage::Domain: - case Stage::Geometry: - case Stage::Fragment: - options.pipelineType = DownstreamCompileOptions::PipelineType::Rasterization; - break; - - case Stage::RayGeneration: - case Stage::Intersection: - case Stage::AnyHit: - case Stage::ClosestHit: - case Stage::Miss: - case Stage::Callable: - options.pipelineType = DownstreamCompileOptions::PipelineType::RayTracing; - break; - } - } - } - - // Add all the search paths (as calculated earlier - they will only be set if this is a pass - // through else will be empty) - options.includePaths = allocator.allocate(includePaths); - - // Add the specified defines (as calculated earlier - they will only be set if this is a - // pass through else will be empty) - { - const auto count = preprocessorDefinitions.getCount(); - auto dst = allocator.getArena().allocateArray<DownstreamCompileOptions::Define>(count); - - Index i = 0; - - for (const auto& [defKey, defValue] : preprocessorDefinitions) - { - auto& define = dst[i]; - - define.nameWithSig = allocator.allocate(defKey); - define.value = allocator.allocate(defValue); - - ++i; - } - options.defines = makeSlice(dst, count); - } - - // Add all of the module libraries - libraries.addRange(linkage->m_libModules.getBuffer(), linkage->m_libModules.getCount()); - } - - auto program = getProgram(); - - // Load embedded precompiled libraries from IR into library artifacts - program->enumerateIRModules( - [&](IRModule* irModule) - { - for (auto globalInst : irModule->getModuleInst()->getChildren()) - { - if (target == CodeGenTarget::DXILAssembly || target == CodeGenTarget::DXIL) - { - if (auto inst = as<IREmbeddedDownstreamIR>(globalInst)) - { - if (inst->getTarget() == CodeGenTarget::DXIL) - { - auto slice = inst->getBlob()->getStringSlice(); - ArtifactDesc desc = - ArtifactDescUtil::makeDescForCompileTarget(SLANG_DXIL); - desc.kind = ArtifactKind::Library; - - auto library = ArtifactUtil::createArtifact(desc); - - library->addRepresentationUnknown(StringBlob::create(slice)); - libraries.add(library); - } - } - } - } - }); - - options.compilerSpecificArguments = allocator.allocate(compilerSpecificArguments); - options.requiredCapabilityVersions = SliceUtil::asSlice(requiredCapabilityVersions); - options.libraries = SliceUtil::asSlice(libraries); - options.libraryPaths = allocator.allocate(libraryPaths); - - if (m_targetProfile.getFamily() == ProfileFamily::DX) - { - options.enablePAQ = m_targetProfile.getVersion() >= ProfileVersion::DX_6_7; - } - - // Compile - ComPtr<IArtifact> artifact; - auto downstreamStartTime = std::chrono::high_resolution_clock::now(); - SLANG_RETURN_ON_FAIL(compiler->compile(options, artifact.writeRef())); - auto downstreamElapsedTime = - (std::chrono::high_resolution_clock::now() - downstreamStartTime).count() * 0.000000001; - getSession()->addDownstreamCompileTime(downstreamElapsedTime); - - SLANG_RETURN_ON_FAIL(passthroughDownstreamDiagnostics(getSink(), compiler, artifact)); - - // Copy over all of the information associated with the source into the output - if (sourceArtifact) - { - for (auto associatedArtifact : sourceArtifact->getAssociated()) - { - artifact->addAssociated(associatedArtifact); - } - } - - // Set the artifact - outArtifact.swap(artifact); - return SLANG_OK; -} - -SlangResult emitSPIRVForEntryPointsDirectly( - CodeGenContext* codeGenContext, - ComPtr<IArtifact>& outArtifact); - -SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); - -static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::DXBytecodeAssembly: - return CodeGenTarget::DXBytecode; - case CodeGenTarget::DXILAssembly: - return CodeGenTarget::DXIL; - case CodeGenTarget::SPIRVAssembly: - return CodeGenTarget::SPIRV; - case CodeGenTarget::WGSLSPIRVAssembly: - return CodeGenTarget::WGSLSPIRV; - default: - return CodeGenTarget::None; - } -} - -static IArtifact* _getSeparateDbgArtifact(IArtifact* artifact) -{ - if (!artifact) - return nullptr; - - // The first associated artifact of kind ObjectCode and SPIRV payload should be the debug - // artifact. - for (auto* associated : artifact->getAssociated()) - { - auto desc = associated->getDesc(); - if (desc.kind == ArtifactKind::ObjectCode && desc.payload == ArtifactPayload::SPIRV) - return associated; - } - - return nullptr; -} - -/// Function to simplify the logic around emitting, and dissassembling -SlangResult CodeGenContext::_emitEntryPoints(ComPtr<IArtifact>& outArtifact) -{ - auto target = getTargetFormat(); - switch (target) - { - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXILAssembly: - case CodeGenTarget::MetalLibAssembly: - case CodeGenTarget::WGSLSPIRVAssembly: - { - // First compile to an intermediate target for the corresponding binary format. - const CodeGenTarget intermediateTarget = _getIntermediateTarget(target); - CodeGenContext intermediateContext(this, intermediateTarget); - - ComPtr<IArtifact> intermediateArtifact; - - SLANG_RETURN_ON_FAIL(intermediateContext._emitEntryPoints(intermediateArtifact)); - intermediateContext.maybeDumpIntermediate(intermediateArtifact); - - // Then disassemble the intermediate binary result to get the desired output - // Output the disassemble - ComPtr<IArtifact> disassemblyArtifact; - SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream( - getSession(), - intermediateArtifact, - getSink(), - disassemblyArtifact.writeRef())); - - // Also disassemble the debug artifact if one exists. - auto debugArtifact = _getSeparateDbgArtifact(intermediateArtifact); - ComPtr<IArtifact> disassemblyDebugArtifact; - if (debugArtifact) - { - SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream( - getSession(), - debugArtifact, - getSink(), - disassemblyDebugArtifact.writeRef())); - disassemblyDebugArtifact->setName(debugArtifact->getName()); - - // The disassembly needs both the metadata for the debug build identifier - // and the debug spirv to be associated with is. - for (auto associated : intermediateArtifact->getAssociated()) - { - if (associated->getDesc().payload == ArtifactPayload::Metadata || - associated->getDesc().payload == ArtifactPayload::PostEmitMetadata) - { - disassemblyArtifact->addAssociated(associated); - break; - } - } - disassemblyArtifact->addAssociated(disassemblyDebugArtifact); - } - - outArtifact.swap(disassemblyArtifact); - return SLANG_OK; - } - case CodeGenTarget::SPIRV: - if (getTargetProgram()->getOptionSet().shouldEmitSPIRVDirectly()) - { - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(this, outArtifact)); - return SLANG_OK; - } - [[fallthrough]]; - case CodeGenTarget::DXIL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::MetalLib: - case CodeGenTarget::PTX: - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostSharedLibrary: - case CodeGenTarget::WGSLSPIRV: - SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); - return SLANG_OK; - case CodeGenTarget::HostVM: - SLANG_RETURN_ON_FAIL(emitHostVMCode(this, outArtifact)); - return SLANG_OK; - default: - break; - } - - return SLANG_FAIL; -} - -// Helper class for recording compile time. -struct CompileTimerRAII -{ - std::chrono::high_resolution_clock::time_point startTime; - Session* session; - CompileTimerRAII(Session* inSession) - { - startTime = std::chrono::high_resolution_clock::now(); - session = inSession; - } - ~CompileTimerRAII() - { - double elapsedTime = std::chrono::duration_cast<std::chrono::microseconds>( - std::chrono::high_resolution_clock::now() - startTime) - .count() / - 1e6; - session->addTotalCompileTime(elapsedTime); - } -}; - -// Do emit logic for a zero or more entry points -SlangResult CodeGenContext::emitEntryPoints(ComPtr<IArtifact>& outArtifact) -{ - CompileTimerRAII recordCompileTime(getSession()); - - auto target = getTargetFormat(); - - switch (target) - { - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXILAssembly: - case CodeGenTarget::SPIRV: - case CodeGenTarget::DXIL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::MetalLib: - case CodeGenTarget::MetalLibAssembly: - case CodeGenTarget::PTX: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostSharedLibrary: - case CodeGenTarget::WGSLSPIRVAssembly: - case CodeGenTarget::HostVM: - { - SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); - - maybeDumpIntermediate(outArtifact); - return SLANG_OK; - } - break; - case CodeGenTarget::GLSL: - case CodeGenTarget::HLSL: - case CodeGenTarget::CUDASource: - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::PyTorchCppBinding: - case CodeGenTarget::CSource: - case CodeGenTarget::Metal: - case CodeGenTarget::WGSL: - { - RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); - - CodeGenContext subContext(this, target, extensionTracker); - - ComPtr<IArtifact> sourceArtifact; - - SLANG_RETURN_ON_FAIL(subContext.emitEntryPointsSource(sourceArtifact)); - - subContext.maybeDumpIntermediate(sourceArtifact); - outArtifact = sourceArtifact; - return SLANG_OK; - } - break; - - case CodeGenTarget::None: - // The user requested no output - return SLANG_OK; - - // Note(tfoley): We currently hit this case when compiling the core module - case CodeGenTarget::Unknown: - return SLANG_OK; - - default: - SLANG_UNEXPECTED("unhandled code generation target"); - break; - } - return SLANG_FAIL; -} - -void EndToEndCompileRequest::writeArtifactToStandardOutput( - IArtifact* artifact, - DiagnosticSink* sink) -{ - // If it's host callable it's not available to write to output - if (isDerivedFrom(artifact->getDesc().kind, ArtifactKind::HostCallable)) - { - return; - } - - auto session = getSession(); - ArtifactOutputUtil::maybeConvertAndWrite( - session, - artifact, - sink, - toSlice("stdout"), - getWriter(WriterChannel::StdOutput)); -} - -String EndToEndCompileRequest::_getWholeProgramPath(TargetRequest* targetReq) -{ - RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - if (m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - return targetInfo->wholeTargetOutputPath; - } - return String(); -} - -String EndToEndCompileRequest::_getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex) -{ - // It is possible that we are dynamically discovering entry - // points (using `[shader(...)]` attributes), so that there - // might be entry points added to the program that did not - // get paths specified via command-line options. - // - RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - if (m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - String outputPath; - if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) - { - return outputPath; - } - } - - return String(); -} - -SlangResult EndToEndCompileRequest::_writeArtifact(const String& path, IArtifact* artifact) -{ - if (path.getLength() > 0) - { - SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::writeToFile(artifact, getSink(), path)); - } - else if (m_containerFormat == ContainerFormat::None) - { - // If we aren't writing to a container and we didn't write to a file, we can output to - // standard output - writeArtifactToStandardOutput(artifact, getSink()); - } - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::_maybeWriteArtifact(const String& path, IArtifact* artifact) -{ - // We don't have to do anything if there is no artifact - if (!artifact) - { - return SLANG_OK; - } - - // If embedding is enabled... - if (m_sourceEmbedStyle != SourceEmbedUtil::Style::None) - { - SourceEmbedUtil::Options options; - - options.style = m_sourceEmbedStyle; - options.variableName = m_sourceEmbedName; - options.language = (SlangSourceLanguage)m_sourceEmbedLanguage; - - ComPtr<IArtifact> embeddedArtifact; - SLANG_RETURN_ON_FAIL(SourceEmbedUtil::createEmbedded(artifact, options, embeddedArtifact)); - - if (!embeddedArtifact) - { - return SLANG_FAIL; - } - SLANG_RETURN_ON_FAIL( - _writeArtifact(SourceEmbedUtil::getPath(path, options), embeddedArtifact)); - return SLANG_OK; - } - else - { - SLANG_RETURN_ON_FAIL(_writeArtifact(path, artifact)); - } - - return SLANG_OK; -} - -// These helper functions are used by the -separate-debug-info command line -// arg to extract the associated artifact containing the debug SPIRV data -// and save it to a file with a .dbg.spv extension. -static String _getDebugSpvPath(const String& basePath) -{ - // Find the last occurrence of ".spv" at the end of the string. - static const char ext[] = ".spv"; - static const char dbgExt[] = ".dbg.spv"; - Index extLen = 4; - if (basePath.getLength() >= extLen && basePath.endsWith(ext)) - { - // Replace the ".spv" extension with ".dbg.spv" - String prefix = String(basePath.subString(0, basePath.getLength() - extLen)); - return prefix + dbgExt; - } - // If it doesn't end with .spv, just append .dbg.spv - return basePath + dbgExt; -} - -SlangResult EndToEndCompileRequest::_maybeWriteDebugArtifact( - TargetProgram* targetProgram, - const String& path, - IArtifact* artifact) -{ - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmitSeparateDebug)) - { - const auto dbgArtifact = _getSeparateDbgArtifact(artifact); - // Check if a debug artifact was actually created (only for SPIR-V targets) - if (dbgArtifact) - { - // The artifact's name may have been set to the debug build id hash, use - // it as the filename if it exists. - String dbgPath = dbgArtifact->getName(); - if (dbgPath.getLength() == 0) - dbgPath = _getDebugSpvPath(path); - else - dbgPath.append(".dbg.spv"); - return _maybeWriteArtifact(dbgPath, dbgArtifact); - } - // If no debug artifact exists (e.g., for non-SPIR-V targets), just silently succeed - // The warning about unsupported targets is already issued during option parsing - } - - return SLANG_OK; -} - -IArtifact* TargetProgram::_createWholeProgramResult( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) -{ - // We want to call `emitEntryPoints` function to generate code that contains - // all the entrypoints defined in `m_program`. - // The current logic of `emitEntryPoints` takes a list of entry-point indices to - // emit code for, so we construct such a list first. - List<Int> entryPointIndices; - - m_entryPointResults.setCount(m_program->getEntryPointCount()); - entryPointIndices.setCount(m_program->getEntryPointCount()); - for (Index i = 0; i < entryPointIndices.getCount(); i++) - entryPointIndices[i] = i; - - CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); - CodeGenContext codeGenContext(&sharedCodeGenContext); - - if (SLANG_FAILED(codeGenContext.emitEntryPoints(m_wholeProgramResult))) - { - return nullptr; - } - - return m_wholeProgramResult; -} - -IArtifact* TargetProgram::_createEntryPointResult( - Int entryPointIndex, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) -{ - // It is possible that entry points got added to the `Program` - // *after* we created this `TargetProgram`, so there might be - // a request for an entry point that we didn't allocate space for. - // - // TODO: Change the construction logic so that a `Program` is - // constructed all at once rather than incrementally, to avoid - // this problem. - // - if (entryPointIndex >= m_entryPointResults.getCount()) - m_entryPointResults.setCount(entryPointIndex + 1); - - - CodeGenContext::EntryPointIndices entryPointIndices; - entryPointIndices.add(entryPointIndex); - - CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); - CodeGenContext codeGenContext(&sharedCodeGenContext); - - codeGenContext.emitEntryPoints(m_entryPointResults[entryPointIndex]); - - return m_entryPointResults[entryPointIndex]; -} - -IArtifact* TargetProgram::getOrCreateWholeProgramResult(DiagnosticSink* sink) -{ - if (m_wholeProgramResult) - return m_wholeProgramResult; - - // If we haven't yet computed a layout for this target - // program, we need to make sure that is done before - // code generation. - // - if (!getOrCreateIRModuleForLayout(sink)) - { - return nullptr; - } - - return _createWholeProgramResult(sink); -} - -IArtifact* TargetProgram::getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink) -{ - if (entryPointIndex >= m_entryPointResults.getCount()) - m_entryPointResults.setCount(entryPointIndex + 1); - - if (IArtifact* artifact = m_entryPointResults[entryPointIndex]) - return artifact; - - // If we haven't yet computed a layout for this target - // program, we need to make sure that is done before - // code generation. - // - if (!getOrCreateIRModuleForLayout(sink)) - { - return nullptr; - } - - return _createEntryPointResult(entryPointIndex, sink); -} - -void EndToEndCompileRequest::generateOutput(TargetProgram* targetProgram) -{ - auto program = targetProgram->getProgram(); - - // Generate target code any entry points that - // have been requested for compilation. - auto entryPointCount = program->getEntryPointCount(); - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) - { - targetProgram->_createWholeProgramResult(getSink(), this); - } - else - { - for (Index ii = 0; ii < entryPointCount; ++ii) - { - targetProgram->_createEntryPointResult(ii, getSink(), this); - } - } -} - - -bool _shouldWriteSourceLocs(Linkage* linkage) -{ - // If debug information or source manager are not avaiable we can't/shouldn't write out locs - if (linkage->m_optionSet.getEnumOption<DebugInfoLevel>(CompilerOptionName::DebugInformation) == - DebugInfoLevel::None || - linkage->getSourceManager() == nullptr) - { - return false; - } - - // Otherwise we do want to write out the locs - return true; -} - -SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) -{ - auto linkage = getLinkage(); - - // Set up options - SerialContainerUtil::WriteOptions options; - - // If debug information is enabled, enable writing out source locs - if (_shouldWriteSourceLocs(linkage)) - { - options.sourceManagerToUseWhenSerializingSourceLocs = linkage->getSourceManager(); - } - - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, options, stream)); - - return SLANG_OK; -} - -static IBoxValue<SourceMap>* _getObfuscatedSourceMap(TranslationUnitRequest* translationUnit) -{ - if (auto module = translationUnit->getModule()) - { - if (auto irModule = module->getIRModule()) - { - return irModule->getObfuscatedSourceMap(); - } - } - return nullptr; -} - -SlangResult EndToEndCompileRequest::maybeCreateContainer() -{ - m_containerArtifact.setNull(); - - List<ComPtr<IArtifact>> artifacts; - - auto linkage = getLinkage(); - - auto program = getSpecializedGlobalAndEntryPointsComponentType(); - - for (auto targetReq : linkage->targets) - { - auto targetProgram = program->getTargetProgram(targetReq); - - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) - { - if (auto artifact = targetProgram->getExistingWholeProgramResult()) - { - if (!targetProgram->getOptionSet().getBoolOption( - CompilerOptionName::EmbedDownstreamIR)) - { - artifacts.add(ComPtr<IArtifact>(artifact)); - } - } - } - else - { - Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - if (auto artifact = targetProgram->getExistingEntryPointResult(ee)) - { - artifacts.add(ComPtr<IArtifact>(artifact)); - } - } - } - } - - // If IR emitting is enabled, add IR to the artifacts - if (m_emitIr && (m_containerFormat == ContainerFormat::SlangModule)) - { - OwnedMemoryStream stream(FileAccess::Write); - SlangResult res = writeContainerToStream(&stream); - if (SLANG_FAILED(res)) - { - getSink()->diagnose(SourceLoc(), Diagnostics::unableToCreateModuleContainer); - return res; - } - - // Need to turn into a blob - List<uint8_t> blobData; - stream.swapContents(blobData); - - auto containerBlob = ListBlob::moveCreate(blobData); - - auto irArtifact = Artifact::create(ArtifactDesc::make( - Artifact::Kind::CompileBinary, - ArtifactPayload::SlangIR, - ArtifactStyle::Unknown)); - irArtifact->addRepresentationUnknown(containerBlob); - - // Add the IR artifact - artifacts.add(irArtifact); - } - - // If there is only one artifact we can use that as the container - if (artifacts.getCount() == 1) - { - m_containerArtifact = artifacts[0]; - } - else - { - m_containerArtifact = ArtifactUtil::createArtifact( - ArtifactDesc::make(ArtifactKind::Container, ArtifactPayload::CompileResults)); - - for (IArtifact* childArtifact : artifacts) - { - m_containerArtifact->addChild(childArtifact); - } - } - - // Get all of the source obfuscated source maps and add those - if (m_containerArtifact) - { - auto frontEndReq = getFrontEndReq(); - - for (auto translationUnit : frontEndReq->translationUnits) - { - // Hmmm do I have to therefore add a map for all translation units(!) - // I guess this is okay in so far as an association can always be looked up by name - if (auto sourceMap = _getObfuscatedSourceMap(translationUnit)) - { - auto artifactDesc = ArtifactDesc::make( - ArtifactKind::Json, - ArtifactPayload::SourceMap, - ArtifactStyle::Obfuscated); - - // Create the source map artifact - auto sourceMapArtifact = - Artifact::create(artifactDesc, sourceMap->get().m_file.getUnownedSlice()); - - // Add the repesentation - sourceMapArtifact->addRepresentation(sourceMap); - - // Associate with the container - m_containerArtifact->addAssociated(sourceMapArtifact); - } - } - } - - return SLANG_OK; -} - -CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(TargetRequest* req) -{ - return req->getOptionSet(); -} - -CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(Index targetIndex) -{ - return m_linkage->targets[targetIndex]->getOptionSet(); -} - -SlangResult EndToEndCompileRequest::maybeWriteContainer(const String& fileName) -{ - // If there is no container, or filename, don't write anything - if (fileName.getLength() == 0 || !m_containerArtifact) - { - return SLANG_OK; - } - - // Filter the containerArtifact into things that can be written - ComPtr<IArtifact> writeArtifact; - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)); - - // Only write if there is something to write - if (writeArtifact) - { - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, fileName)); - } - - return SLANG_OK; -} - -static void _writeString(Stream& stream, const char* string) -{ - stream.write(string, strlen(string)); -} - -static void _escapeDependencyString(const char* string, StringBuilder& outBuilder) -{ - // make has unusual escaping rules, but we only care about characters that are acceptable in a - // path - for (const char* p = string; *p; ++p) - { - char c = *p; - switch (c) - { - case ' ': - case ':': - case '#': - case '[': - case ']': - case '\\': - outBuilder.appendChar('\\'); - break; - - case '$': - outBuilder.appendChar('$'); - break; - } - - outBuilder.appendChar(c); - } -} - -// Writes a line to the file stream, formatted like this: -// <output-file>: <dependency-file> <dependency-file...> -static void _writeDependencyStatement( - Stream& stream, - EndToEndCompileRequest* compileRequest, - const String& outputPath) -{ - if (outputPath.getLength() == 0) - return; - - StringBuilder builder; - _escapeDependencyString(outputPath.begin(), builder); - _writeString(stream, builder.begin()); - _writeString(stream, ": "); - - int dependencyCount = compileRequest->getDependencyFileCount(); - for (int dependencyIndex = 0; dependencyIndex < dependencyCount; ++dependencyIndex) - { - builder.clear(); - _escapeDependencyString(compileRequest->getDependencyFilePath(dependencyIndex), builder); - _writeString(stream, builder.begin()); - _writeString(stream, (dependencyIndex + 1 < dependencyCount) ? " " : "\n"); - } -} - -// Writes a file with dependency info, with one line in the output file per compile product. -static SlangResult _writeDependencyFile(EndToEndCompileRequest* compileRequest) -{ - if (compileRequest->m_dependencyOutputPath.getLength() == 0) - return SLANG_OK; - - FileStream stream; - SLANG_RETURN_ON_FAIL(stream.init( - compileRequest->m_dependencyOutputPath, - FileMode::Create, - FileAccess::Write, - FileShare::ReadWrite)); - - auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); - - // Iterate over all the targets and their outputs - for (const auto& targetReq : linkage->targets) - { - if (compileRequest->getTargetOptionSet(targetReq).getBoolOption( - CompilerOptionName::GenerateWholeProgram)) - { - RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - _writeDependencyStatement( - stream, - compileRequest, - targetInfo->wholeTargetOutputPath); - } - } - else - { - Index entryPointCount = program->getEntryPointCount(); - for (Index entryPointIndex = 0; entryPointIndex < entryPointCount; ++entryPointIndex) - { - RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - String outputPath; - if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) - { - _writeDependencyStatement(stream, compileRequest, outputPath); - } - } - } - } - } - - // When the output is a binary module, linkage->targets can be empty. So - // we need to do their dependencies separately. - if (compileRequest->m_containerFormat == ContainerFormat::SlangModule) - { - _writeDependencyStatement(stream, compileRequest, compileRequest->m_containerOutputPath); - } - - return SLANG_OK; -} - - -void EndToEndCompileRequest::generateOutput(ComponentType* program) -{ - // When dynamic dispatch is disabled, the program must - // be fully specialized by now. So we check if we still - // have unspecialized generic/existential parameters, - // and report them as an error. - // - auto specializationParamCount = program->getSpecializationParamCount(); - if (getOptionSet().getBoolOption(CompilerOptionName::DisableDynamicDispatch) && - specializationParamCount != 0) - { - auto sink = getSink(); - - for (Index ii = 0; ii < specializationParamCount; ++ii) - { - auto specializationParam = program->getSpecializationParam(ii); - if (auto decl = as<Decl>(specializationParam.object)) - { - sink->diagnose( - specializationParam.loc, - Diagnostics::specializationParameterOfNameNotSpecialized, - decl); - } - else if (auto type = as<Type>(specializationParam.object)) - { - sink->diagnose( - specializationParam.loc, - Diagnostics::specializationParameterOfNameNotSpecialized, - type); - } - else - { - sink->diagnose( - specializationParam.loc, - Diagnostics::specializationParameterNotSpecialized); - } - } - - return; - } - - - // Go through the code-generation targets that the user - // has specified, and generate code for each of them. - // - auto linkage = getLinkage(); - for (auto targetReq : linkage->targets) - { - if (targetReq->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) - continue; - - auto targetProgram = program->getTargetProgram(targetReq); - generateOutput(targetProgram); - } -} - -void EndToEndCompileRequest::generateOutput() -{ - SLANG_PROFILE; - generateOutput(getSpecializedGlobalAndEntryPointsComponentType()); - - // If we are in command-line mode, we might be expected to actually - // write output to one or more files here. - - if (m_isCommandLineCompile && m_containerFormat == ContainerFormat::None) - { - auto linkage = getLinkage(); - auto program = getSpecializedGlobalAndEntryPointsComponentType(); - - for (auto targetReq : linkage->targets) - { - auto targetProgram = program->getTargetProgram(targetReq); - - if (targetProgram->getOptionSet().getBoolOption( - CompilerOptionName::GenerateWholeProgram)) - { - if (const auto artifact = targetProgram->getExistingWholeProgramResult()) - { - const auto path = _getWholeProgramPath(targetReq); - - _maybeWriteArtifact(path, artifact); - - // If we are compiling separate debug info, check for the additional - // SPIRV artifact and write that if needed. - _maybeWriteDebugArtifact(targetProgram, path, artifact); - } - } - else - { - Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - if (const auto artifact = targetProgram->getExistingEntryPointResult(ee)) - { - const auto path = _getEntryPointPath(targetReq, ee); - - _maybeWriteArtifact(path, artifact); - - // If we are compiling separate debug info, check for the additional - // SPIRV artifact and write that if needed. - _maybeWriteDebugArtifact(targetProgram, path, artifact); - } - } - } - } - } - - // Maybe create the container - maybeCreateContainer(); - - // If it's a command line compile we may need to write the container to a file - if (m_isCommandLineCompile) - { - // TODO(JS): - // We could write the container into a source embedded format potentially - - maybeWriteContainer(m_containerOutputPath); - - _writeDependencyFile(this); - } -} - -// Debug logic for dumping intermediate outputs - - -void CodeGenContext::_dumpIntermediateMaybeWithAssembly(IArtifact* artifact) -{ - _dumpIntermediate(artifact); - - ComPtr<IArtifact> assembly; - ArtifactOutputUtil::maybeDisassemble(getSession(), artifact, nullptr, assembly); - - if (assembly) - { - _dumpIntermediate(assembly); - } -} - -void CodeGenContext::_dumpIntermediate(IArtifact* artifact) -{ - ComPtr<ISlangBlob> blob; - if (SLANG_FAILED(artifact->loadBlob(ArtifactKeep::No, blob.writeRef()))) - { - return; - } - _dumpIntermediate(artifact->getDesc(), blob->getBufferPointer(), blob->getBufferSize()); -} - -void CodeGenContext::_dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size) -{ - // Try to generate a unique ID for the file to dump, - // even in cases where there might be multiple threads - // doing compilation. - // - // This is primarily a debugging aid, so we don't - // really need/want to do anything too elaborate - - static std::atomic<uint32_t> counter(0); - - const uint32_t id = ++counter; - - // Just use the counter for the 'base name' - StringBuilder basename; - - // Add the prefix - basename << getIntermediateDumpPrefix(); - - // Add the id - basename << int(id); - - // Work out the filename based on the desc and the basename - StringBuilder filename; - ArtifactDescUtil::calcNameForDesc(desc, basename.getUnownedSlice(), filename); - - // If didn't produce a filename, use basename with .unknown extension - if (filename.getLength() == 0) - { - filename = basename; - filename << ".unknown"; - } - - // Write to a file - ArtifactOutputUtil::writeToFile(desc, data, size, filename); -} - -void CodeGenContext::maybeDumpIntermediate(IArtifact* artifact) -{ - if (!shouldDumpIntermediates()) - return; - - - _dumpIntermediateMaybeWithAssembly(artifact); -} - -IRDumpOptions CodeGenContext::getIRDumpOptions() -{ - if (auto endToEndReq = isEndToEndCompile()) - { - return endToEndReq->getFrontEndReq()->m_irDumpOptions; - } - return IRDumpOptions(); -} - -bool CodeGenContext::shouldValidateIR() -{ - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ValidateIr); -} - -bool CodeGenContext::shouldSkipSPIRVValidation() -{ - return getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::SkipSPIRVValidation); -} - -bool CodeGenContext::shouldDumpIR() -{ - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); -} - -bool CodeGenContext::shouldSkipDownstreamLinking() -{ - return getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::SkipDownstreamLinking); -} - -bool CodeGenContext::shouldReportCheckpointIntermediates() -{ - return getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::ReportCheckpointIntermediates); -} - -bool CodeGenContext::shouldDumpIntermediates() -{ - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); -} - -bool CodeGenContext::shouldTrackLiveness() -{ - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); -} - -String CodeGenContext::getIntermediateDumpPrefix() -{ - return getTargetProgram()->getOptionSet().getStringOption( - CompilerOptionName::DumpIntermediatePrefix); -} - -bool CodeGenContext::getUseUnknownImageFormatAsDefault() -{ - return getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::DefaultImageFormatUnknown); -} - -bool CodeGenContext::isSpecializationDisabled() -{ - return getTargetProgram()->getOptionSet().getBoolOption( - CompilerOptionName::DisableSpecialization); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Module::serialize(ISlangBlob** outSerializedBlob) -{ - SerialContainerUtil::WriteOptions writeOptions; - OwnedMemoryStream memoryStream(FileAccess::Write); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, writeOptions, &memoryStream)); - *outSerializedBlob = RawBlob::create( - memoryStream.getContents().getBuffer(), - (size_t)memoryStream.getContents().getCount()) - .detach(); - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Module::writeToFile(char const* fileName) -{ - SerialContainerUtil::WriteOptions writeOptions; - FileStream fileStream; - SLANG_RETURN_ON_FAIL(fileStream.init(fileName, FileMode::Create)); - return SerialContainerUtil::write(this, writeOptions, &fileStream); -} - -SLANG_NO_THROW const char* SLANG_MCALL Module::getName() -{ - if (m_name) - return m_name->text.getBuffer(); - return nullptr; -} - -SLANG_NO_THROW const char* SLANG_MCALL Module::getFilePath() -{ - if (m_pathInfo.hasFoundPath()) - return m_pathInfo.foundPath.getBuffer(); - return nullptr; -} - -SLANG_NO_THROW const char* SLANG_MCALL Module::getUniqueIdentity() -{ - if (m_pathInfo.hasUniqueIdentity()) - return m_pathInfo.getMostUniqueIdentity().getBuffer(); - return nullptr; -} - -SLANG_NO_THROW SlangInt32 SLANG_MCALL Module::getDependencyFileCount() -{ - return (SlangInt32)getFileDependencies().getCount(); -} - -SLANG_NO_THROW char const* SLANG_MCALL Module::getDependencyFilePath(SlangInt32 index) -{ - SourceFile* sourceFile = getFileDependencies()[index]; - return sourceFile->getPathInfo().hasFoundPath() - ? sourceFile->getPathInfo().getMostUniqueIdentity().getBuffer() - : nullptr; -} - -void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink); - -void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) -{ - if (m_entryPoints.getCount() > 0) - return; - _discoverEntryPointsImpl(m_moduleDecl, sink, targets); -} -void Module::_discoverEntryPointsImpl( - ContainerDecl* containerDecl, - DiagnosticSink* sink, - const List<RefPtr<TargetRequest>>& targets) -{ - for (auto globalDecl : containerDecl->getDirectMemberDecls()) - { - auto maybeFuncDecl = globalDecl; - if (auto genericDecl = as<GenericDecl>(maybeFuncDecl)) - { - maybeFuncDecl = genericDecl->inner; - } - - if (as<NamespaceDeclBase>(globalDecl) || as<FileDecl>(globalDecl) || - as<StructDecl>(globalDecl)) - { - _discoverEntryPointsImpl(as<ContainerDecl>(globalDecl), sink, targets); - continue; - } - - auto funcDecl = as<FuncDecl>(maybeFuncDecl); - if (!funcDecl) - continue; - - Profile profile; - bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint( - profile, - getLinkage()->m_optionSet, - targets, - funcDecl, - sink); - if (!resolvedStageOfProfileWithEntryPoint) - { - // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. We'll not do this when compiling for - // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. - // - - bool allTargetsCUDARelated = true; - for (auto target : targets) - { - if (!isCUDATarget(target) && - target->getTarget() != CodeGenTarget::PyTorchCppBinding) - { - allTargetsCUDARelated = false; - break; - } - } - - if (allTargetsCUDARelated && targets.getCount() > 0) - continue; - - bool canDetermineStage = false; - for (auto modifier : funcDecl->modifiers) - { - if (as<NumThreadsAttribute>(modifier)) - { - if (funcDecl->findModifier<OutputTopologyAttribute>()) - profile.setStage(Stage::Mesh); - else - profile.setStage(Stage::Compute); - canDetermineStage = true; - break; - } - else if (as<PatchConstantFuncAttribute>(modifier)) - { - profile.setStage(Stage::Hull); - canDetermineStage = true; - break; - } - } - if (!canDetermineStage) - continue; - } - - RefPtr<EntryPoint> entryPoint = - EntryPoint::create(getLinkage(), makeDeclRef(funcDecl), profile); - - validateEntryPoint(entryPoint, sink); - - // Note: in the case that the user didn't explicitly - // specify entry points and we are instead compiling - // a shader "library," then we do not want to automatically - // combine the entry points into groups in the generated - // `Program`, since that would be slightly too magical. - // - // Instead, each entry point will end up in a singleton - // group, so that its entry-point parameters lay out - // independent of the others. - // - _addEntryPoint(entryPoint); - } -} } // namespace Slang diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index d45e796d9..934f86096 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1,6 +1,23 @@ +// slang-compiler.h #ifndef SLANG_COMPILER_H_INCLUDED #define SLANG_COMPILER_H_INCLUDED +// +// This file provides an umbrella header that ties together +// the headers for a bunch of the core types used by the +// Slang compiler implementation: the global session, session, +// modules, entry points, etc. +// +// Note: this file used to be a kind of kitchen-sink header +// with thousands of lines of declarations, and even though +// those declarations have migrated to their own files, this +// header has been otherwise left as-is to avoid breaking +// all of the code that `#include`s it. +// +// Please avoid adding new declarations in here without a clear +// motivation for *why* they belong here. +// + #include "../compiler-core/slang-artifact-representation-impl.h" #include "../compiler-core/slang-command-line-args.h" #include "../compiler-core/slang-downstream-compiler-util.h" @@ -15,16 +32,29 @@ #include "../core/slang-file-system.h" #include "../core/slang-shared-library.h" #include "../core/slang-std-writers.h" +#include "slang-base-type-info.h" #include "slang-capability.h" +#include "slang-code-gen.h" #include "slang-com-ptr.h" +#include "slang-compile-request.h" +#include "slang-compiler-api.h" #include "slang-compiler-options.h" #include "slang-content-assist-info.h" #include "slang-diagnostics.h" +#include "slang-end-to-end-request.h" +#include "slang-global-session.h" #include "slang-hlsl-to-vulkan-layout-options.h" +#include "slang-linkable-impls.h" +#include "slang-linkable.h" +#include "slang-module.h" +#include "slang-pass-through.h" #include "slang-preprocessor.h" #include "slang-profile.h" #include "slang-serialize-ir-types.h" +#include "slang-session.h" #include "slang-syntax.h" +#include "slang-target.h" +#include "slang-translation-unit.h" #include "slang.h" #include <chrono> @@ -33,7 +63,6 @@ namespace Slang { struct PathInfo; struct IncludeHandler; -struct SharedSemanticsContext; struct ModuleChunk; class ProgramLayout; @@ -50,61 +79,6 @@ enum class CompilerMode GenerateChoice }; -enum class StageTarget -{ - Unknown, - VertexShader, - HullShader, - DomainShader, - GeometryShader, - FragmentShader, - ComputeShader, -}; - -enum class CodeGenTarget : SlangCompileTargetIntegral -{ - Unknown = SLANG_TARGET_UNKNOWN, - None = SLANG_TARGET_NONE, - GLSL = SLANG_GLSL, - HLSL = SLANG_HLSL, - SPIRV = SLANG_SPIRV, - SPIRVAssembly = SLANG_SPIRV_ASM, - DXBytecode = SLANG_DXBC, - DXBytecodeAssembly = SLANG_DXBC_ASM, - DXIL = SLANG_DXIL, - DXILAssembly = SLANG_DXIL_ASM, - CSource = SLANG_C_SOURCE, - CPPSource = SLANG_CPP_SOURCE, - PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, - HostCPPSource = SLANG_HOST_CPP_SOURCE, - HostExecutable = SLANG_HOST_EXECUTABLE, - HostSharedLibrary = SLANG_HOST_SHARED_LIBRARY, - ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, - ShaderHostCallable = SLANG_SHADER_HOST_CALLABLE, - CUDASource = SLANG_CUDA_SOURCE, - PTX = SLANG_PTX, - CUDAObjectCode = SLANG_CUDA_OBJECT_CODE, - ObjectCode = SLANG_OBJECT_CODE, - HostHostCallable = SLANG_HOST_HOST_CALLABLE, - Metal = SLANG_METAL, - MetalLib = SLANG_METAL_LIB, - MetalLibAssembly = SLANG_METAL_LIB_ASM, - WGSL = SLANG_WGSL, - WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM, - WGSLSPIRV = SLANG_WGSL_SPIRV, - HostVM = SLANG_HOST_VM, - CountOf = SLANG_TARGET_COUNT_OF, -}; - -bool isHeterogeneousTarget(CodeGenTarget target); - -void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); - -enum class ContainerFormat : SlangContainerFormatIntegral -{ - None = SLANG_CONTAINER_FORMAT_NONE, - SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, -}; enum class LineDirectiveMode : SlangLineDirectiveModeIntegral { @@ -169,1866 +143,9 @@ class Linkage; class Module; class TranslationUnitRequest; -/// Information collected about global or entry-point shader parameters -struct ShaderParamInfo -{ - DeclRef<VarDeclBase> paramDeclRef; - Int firstSpecializationParamIndex = 0; - Int specializationParamCount = 0; -}; - -/// A request for the front-end to find and validate an entry-point function -struct FrontEndEntryPointRequest : RefObject -{ -public: - /// Create a request for an entry point. - FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile); - - /// Get the parent front-end compile request. - FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } - - /// Get the translation unit that contains the entry point. - TranslationUnitRequest* getTranslationUnit(); - - /// Get the name of the entry point to find. - Name* getName() { return m_name; } - - /// Get the stage that the entry point is to be compiled for - Stage getStage() { return m_profile.getStage(); } - - /// Get the profile that the entry point is to be compiled for - Profile getProfile() { return m_profile; } - - /// Get the index to the translation unit - int getTranslationUnitIndex() const { return m_translationUnitIndex; } - -private: - // The parent compile request - FrontEndCompileRequest* m_compileRequest; - - // The index of the translation unit that will hold the entry point - int m_translationUnitIndex; - - // The name of the entry point function to look for - Name* m_name; - - // The profile to compile for (including stage) - Profile m_profile; -}; - -/// Tracks an ordered list of modules that something depends on. -/// TODO: Shader caching currently relies on this being in well defined order. -struct ModuleDependencyList -{ -public: - /// Get the list of modules that are depended on. - List<Module*> const& getModuleList() { return m_moduleList; } - - /// Add a module and everything it depends on to the list. - void addDependency(Module* module); - - /// Add a module to the list, but not the modules it depends on. - void addLeafDependency(Module* module); - -private: - void _addDependency(Module* module); - - List<Module*> m_moduleList; - HashSet<Module*> m_moduleSet; -}; - -/// Tracks an unordered list of source files that something depends on -/// TODO: Shader caching currently relies on this being in well defined order. -struct FileDependencyList -{ -public: - /// Get the list of files that are depended on. - List<SourceFile*> const& getFileList() { return m_fileList; } - - /// Add a file to the list, if it is not already present - void addDependency(SourceFile* sourceFile); - - /// Add all of the paths that `module` depends on to the list - void addDependency(Module* module); - - void clear() - { - m_fileList.clear(); - m_fileSet.clear(); - } - -private: - // TODO: We are using a `HashSet` here to deduplicate - // the paths so that we don't return the same path - // multiple times from `getFilePathList`, but because - // order isn't important, we could potentially do better - // in terms of memory (at some cost in performance) by - // just sorting the `m_fileList` every once in - // a while and then deduplicating. - - List<SourceFile*> m_fileList; - HashSet<SourceFile*> m_fileSet; -}; - - -class EntryPoint; - -class ComponentType; -class ComponentTypeVisitor; - -/// Base class for "component types" that represent the pieces a final -/// shader program gets linked together from. -/// -class ComponentType : public RefObject, - public slang::IComponentType, - public slang::IComponentType2, - public slang::IModulePrecompileService_Experimental -{ -public: - // - // ISlangUnknown interface - // - - SLANG_REF_OBJECT_IUNKNOWN_ALL; - ISlangUnknown* getInterface(Guid const& guid); - - // - // slang::IComponentType interface - // - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics); - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL - renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - /// ComponentType is the only class inheriting from IComponentType that provides a - /// meaningful implementation for this function. All others should forward these and - /// implement `buildHash`. - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override; - - // - // slang::IComponentType2 interface - // - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - // - // slang::IModulePrecompileService interface - // - SLANG_NO_THROW SlangResult SLANG_MCALL - precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( - SlangInt dependencyIndex, - slang::IModule** outModule, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - CompilerOptionSet& getOptionSet() { return m_optionSet; } - - /// Get the linkage (aka "session" in the public API) for this component type. - Linkage* getLinkage() { return m_linkage; } - - /// Get the target-specific version of this program for the given `target`. - /// - /// The `target` must be a target on the `Linkage` that was used to create this program. - TargetProgram* getTargetProgram(TargetRequest* target); - - /// Update the hash builder with the dependencies for this component type. - virtual void buildHash(DigestBuilder<SHA1>& builder) = 0; - - /// Get the number of entry points linked into this component type. - virtual Index getEntryPointCount() = 0; - - /// Get one of the entry points linked into this component type. - virtual RefPtr<EntryPoint> getEntryPoint(Index index) = 0; - - /// Get the mangled name of one of the entry points linked into this component type. - virtual String getEntryPointMangledName(Index index) = 0; - - /// Get the name override of one of the entry points linked into this component type. - virtual String getEntryPointNameOverride(Index index) = 0; - - /// Get the number of global shader parameters linked into this component type. - virtual Index getShaderParamCount() = 0; - - /// Get one of the global shader parametesr linked into this component type. - virtual ShaderParamInfo getShaderParam(Index index) = 0; - - /// Get the specialization parameter at `index`. - virtual SpecializationParam const& getSpecializationParam(Index index) = 0; - - /// Get the number of "requirements" that this component type has. - /// - /// A requirement represents another component type that this component - /// needs in order to function correctly. For example, the dependency - /// of one module on another module that it `import`s is represented - /// as a requirement, as is the dependency of an entry point on the - /// module that defines it. - /// - virtual Index getRequirementCount() = 0; - - /// Get the requirement at `index`. - virtual RefPtr<ComponentType> getRequirement(Index index) = 0; - - /// Parse a type from a string, in the context of this component type. - /// - /// Any names in the string will be resolved using the modules - /// referenced by the program. - /// - /// On an error, returns null and reports diagnostic messages - /// to the provided `sink`. - /// - /// TODO: This function shouldn't be on the base class, since - /// it only really makes sense on `Module`. - /// - Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink); - - Expr* findDeclFromString(String const& name, DiagnosticSink* sink); - - Expr* findDeclFromStringInType( - Type* type, - String const& name, - LookupMask mask, - DiagnosticSink* sink); - - bool isSubType(Type* subType, Type* superType); - - Dictionary<String, IntVal*>& getMangledNameToIntValMap(); - ConstantIntVal* tryFoldIntVal(IntVal* intVal); - - /// Get a list of modules that this component type depends on. - /// - virtual List<Module*> const& getModuleDependencies() = 0; - - /// Get the full list of source files this component type depends on. - /// - virtual List<SourceFile*> const& getFileDependencies() = 0; - - /// Callback for use with `enumerateIRModules` - typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); - - /// Invoke `callback` on all the IR modules that are (transitively) linked into this component - /// type. - void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); - - /// Invoke `callback` on all the IR modules that are (transitively) linked into this component - /// type. - template<typename F> - void enumerateIRModules(F const& callback) - { - struct Helper - { - static void helper(IRModule* irModule, void* userData) { (*(F*)userData)(irModule); } - }; - enumerateIRModules(&Helper::helper, (void*)&callback); - } - - /// Callback for use with `enumerateModules` - typedef void (*EnumerateModulesCallback)(Module* module, void* userData); - - /// Invoke `callback` on all the modules that are (transitively) linked into this component - /// type. - void enumerateModules(EnumerateModulesCallback callback, void* userData); - - /// Invoke `callback` on all the modules that are (transitively) linked into this component - /// type. - template<typename F> - void enumerateModules(F const& callback) - { - struct Helper - { - static void helper(Module* module, void* userData) { (*(F*)userData)(module); } - }; - enumerateModules(&Helper::helper, (void*)&callback); - } - - /// Side-band information generated when specializing this component type. - /// - /// Difference subclasses of `ComponentType` are expected to create their - /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. - /// Later, whenever we want to use a specialized component type we will - /// also have the `SpecializationInfo` available and will expect it to - /// have the correct (subclass-specific) type. - /// - class SpecializationInfo : public RefObject - { - }; - - /// Validate the given specialization `args` and compute any side-band specialization info. - /// - /// Any errors will be reported to `sink`, which can thus be used to test - /// if the operation was successful. - /// - /// A null return value is allowed, since not all subclasses require - /// custom side-band specialization information. - /// - /// This function is an implementation detail of `specialize()`. - /// - virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) = 0; - - /// Validate the given specialization `args` and compute any side-band specialization info. - /// - /// Any errors will be reported to `sink`, which can thus be used to test - /// if the operation was successful. - /// - /// A null return value is allowed, since not all subclasses require - /// custom side-band specialization information. - /// - /// This function is an implementation detail of `specialize()`. - /// - RefPtr<SpecializationInfo> _validateSpecializationArgs( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) - { - if (argCount == 0) - return nullptr; - return _validateSpecializationArgsImpl(args, argCount, sink); - } - - /// Specialize this component type given `specializationArgs` - /// - /// Any diagnostics will be reported to `sink`, which can be used - /// to determine if the operation was successful. It is allowed - /// for this operation to have a non-null return even when an - /// error is ecnountered. - /// - RefPtr<ComponentType> specialize( - SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - DiagnosticSink* sink); - - /// Invoke `visitor` on this component type, using the appropriate dynamic type. - /// - /// This function implements the "visitor pattern" for `ComponentType`. - /// - /// If the `specializationInfo` argument is non-null, it must be specialization - /// information generated for this specific component type by `_validateSpecializationArgs`. - /// In that case, appropriately-typed specialization information will be passed - /// when invoking the `visitor`. - /// - virtual void acceptVisitor( - ComponentTypeVisitor* visitor, - SpecializationInfo* specializationInfo) = 0; - - /// Create a scope suitable for looking up names or parsing specialization arguments. - /// - /// This facility is only needed to support legacy APIs for string-based lookup - /// and parsing via Slang reflection, and is not recommended for future APIs to use. - /// - Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder); - -protected: - ComponentType(Linkage* linkage); - -protected: - Linkage* m_linkage; - - CompilerOptionSet m_optionSet; - - // Cache of target-specific programs for each target. - Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; - - // Any types looked up dynamically using `getTypeFromString` - // - // TODO: Remove this. Type lookup should only be supported on `Module`s. - // - Dictionary<String, Type*> m_types; - - // Any decls looked up dynamically using `findDeclFromString`. - Dictionary<String, Expr*> m_decls; - - Scope* m_lookupScope = nullptr; - std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal; - - Dictionary<Int, ComPtr<IArtifact>> m_targetArtifacts; -}; - -/// A component type built up from other component types. -class CompositeComponentType : public ComponentType -{ -public: - static RefPtr<ComponentType> create( - Linkage* linkage, - List<RefPtr<ComponentType>> const& childComponents); - - virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; - - List<RefPtr<ComponentType>> const& getChildComponents() { return m_childComponents; }; - Index getChildComponentCount() { return m_childComponents.getCount(); } - RefPtr<ComponentType> getChildComponent(Index index) { return m_childComponents[index]; } - - Index getEntryPointCount() SLANG_OVERRIDE; - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE; - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; - - Index getShaderParamCount() SLANG_OVERRIDE; - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; - - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; - - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; - List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE; - - class CompositeSpecializationInfo : public SpecializationInfo - { - public: - List<RefPtr<SpecializationInfo>> childInfos; - }; - -protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - -public: - CompositeComponentType(Linkage* linkage, List<RefPtr<ComponentType>> const& childComponents); - -private: - List<RefPtr<ComponentType>> m_childComponents; - - // The following arrays hold the concatenated entry points, parameters, - // etc. from the child components. This approach allows for reasonably - // fast (constant time) access through operations like `getShaderParam`, - // but means that the memory usage of a composite is proportional to - // the sum of the memory usage of the children, rather than being fixed - // by the number of children (as it would be if we just stored - // `m_childComponents`). - // - // TODO: We could conceivably build some O(numChildren) arrays that - // support binary-search to provide logarithmic-time access to entry - // points, parameters, etc. while giving a better overall memory usage. - // - List<EntryPoint*> m_entryPoints; - List<String> m_entryPointMangledNames; - List<String> m_entryPointNameOverrides; - List<ShaderParamInfo> m_shaderParams; - List<SpecializationParam> m_specializationParams; - List<ComponentType*> m_requirements; - - ModuleDependencyList m_moduleDependencyList; - FileDependencyList m_fileDependencyList; -}; - -/// A component type created by specializing another component type. -class SpecializedComponentType : public ComponentType -{ -public: - SpecializedComponentType( - ComponentType* base, - SpecializationInfo* specializationInfo, - List<SpecializationArg> const& specializationArgs, - DiagnosticSink* sink); - - virtual void buildHash(DigestBuilder<SHA1>& builer) SLANG_OVERRIDE; - - /// Get the base (unspecialized) component type that is being specialized. - RefPtr<ComponentType> getBaseComponentType() { return m_base; } - - RefPtr<SpecializationInfo> getSpecializationInfo() { return m_specializationInfo; } - - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized entry point may have many parameters, but zero arguments. - Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } - - /// Get the existential type argument (type and witness table) at `index`. - SpecializationArg const& getSpecializationArg(Index index) - { - return m_specializationArgs[index]; - } - - /// Get an array of all existential type arguments. - SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } - - Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE - { - return m_base->getEntryPoint(index); - } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; - - Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - return m_base->getShaderParam(index); - } - - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - static SpecializationParam dummy; - return dummy; - } - - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } - List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } - - RefPtr<IRModule> getIRModule() { return m_irModule; } - - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - -protected: - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE - { - SLANG_UNUSED(args); - SLANG_UNUSED(argCount); - SLANG_UNUSED(sink); - return nullptr; - } - -private: - RefPtr<ComponentType> m_base; - RefPtr<SpecializationInfo> m_specializationInfo; - SpecializationArgs m_specializationArgs; - RefPtr<IRModule> m_irModule; - - List<String> m_entryPointMangledNames; - List<String> m_entryPointNameOverrides; - - List<Module*> m_moduleDependencies; - List<SourceFile*> m_fileDependencies; - List<RefPtr<ComponentType>> m_requirements; -}; - -class RenamedEntryPointComponentType : public ComponentType -{ -public: - using Super = ComponentType; - - RenamedEntryPointComponentType(ComponentType* base, String newName); - - ComponentType* getBase() { return m_base.Ptr(); } - - // Forward `IComponentType` methods - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } - - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, - targetIndex, - outSharedLibrary, - outDiagnostics); - } - - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE - { - return m_base->getModuleDependencies(); - } - List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE - { - return m_base->getFileDependencies(); - } - - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE - { - return m_base->getSpecializationParamCount(); - } - - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE - { - return m_base->getSpecializationParam(index); - } - - Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); } - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE - { - return m_base->getRequirement(index); - } - Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE - { - return m_base->getEntryPoint(index); - } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE - { - return m_base->getEntryPointMangledName(index); - } - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - return m_entryPointNameOverride; - } - - Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - return m_base->getShaderParam(index); - } - - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; - -private: - RefPtr<ComponentType> m_base; - String m_entryPointNameOverride; - -protected: - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE - { - return m_base->_validateSpecializationArgsImpl(args, argCount, sink); - } -}; - -/// Describes an entry point for the purposes of layout and code generation. -/// -/// This class also tracks any generic arguments to the entry point, -/// in the case that it is a specialization of a generic entry point. -/// -/// There is also a provision for creating a "dummy" entry point for -/// the purposes of pass-through compilation modes. Only the -/// `getName()` and `getProfile()` methods should be expected to -/// return useful data on pass-through entry points. -/// -class EntryPoint : public ComponentType, public slang::IEntryPoint -{ - typedef ComponentType Super; - -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - ISlangUnknown* getInterface(const Guid& guid); - - - // Forward `IComponentType` methods - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } - - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointMetadata( - entryPointIndex, - targetIndex, - outMetadata, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCompileResult( - entryPointIndex, - targetIndex, - outCompileResult, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, - targetIndex, - outSharedLibrary, - outDiagnostics); - } - - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); - } - - virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; - - /// Create an entry point that refers to the given function. - static RefPtr<EntryPoint> create( - Linkage* linkage, - DeclRef<FuncDecl> funcDeclRef, - Profile profile); - - /// Get the function decl-ref, including any generic arguments. - DeclRef<FuncDecl> getFuncDeclRef() { return m_funcDeclRef; } - - /// Get the function declaration (without generic arguments). - FuncDecl* getFuncDecl() { return m_funcDeclRef.getDecl(); } - - /// Get the name of the entry point - Name* getName() { return m_name; } - - /// Get the profile associated with the entry point - /// - /// Note: only the stage part of the profile is expected - /// to contain useful data, but certain legacy code paths - /// allow for "shader model" information to come via this path. - /// - Profile getProfile() { return m_profile; } - - /// Get the stage that the entry point is for. - Stage getStage() { return m_profile.getStage(); } - - /// Get the module that contains the entry point. - Module* getModule(); - - /// Get a list of modules that this entry point depends on. - /// - /// This will include the module that defines the entry point (see `getModule()`), - /// but may also include modules that are required by its generic type arguments. - /// - List<Module*> const& getModuleDependencies() - SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } - List<SourceFile*> const& getFileDependencies() - SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); } - - /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. - static RefPtr<EntryPoint> createDummyForPassThrough( - Linkage* linkage, - Name* name, - Profile profile); - - /// Create a dummy `EntryPoint` that stands in for a serialized entry point - static RefPtr<EntryPoint> createDummyForDeserialize( - Linkage* linkage, - Name* name, - Profile profile, - String mangledName); - - /// Get the number of existential type parameters for the entry point. - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; - - /// Get the existential type parameter at `index`. - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; - - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - - SpecializationParams const& getExistentialSpecializationParams() - { - return m_existentialSpecializationParams; - } - - Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } - Index getExistentialSpecializationParamCount() - { - return m_existentialSpecializationParams.getCount(); - } - - /// Get an array of all entry-point shader parameters. - List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; } - - Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return this; - } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; - - Index getShaderParamCount() SLANG_OVERRIDE { return 0; } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return ShaderParamInfo(); - } - - class EntryPointSpecializationInfo : public SpecializationInfo - { - public: - DeclRef<FuncDecl> specializedFuncDeclRef; - List<ExpandedSpecializationArg> existentialSpecializationArgs; - }; - - SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE - { - return (slang::FunctionReflection*)m_funcDeclRef.declRefBase; - } - -protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - -private: - EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef); - - void _collectGenericSpecializationParamsRec(Decl* decl); - void _collectShaderParams(); - - // The name of the entry point function (e.g., `main`) - // - Name* m_name = nullptr; - - // The declaration of the entry-point function itself. - // - DeclRef<FuncDecl> m_funcDeclRef; - - /// The mangled name of the entry point function - String m_mangledName; - - SpecializationParams m_genericSpecializationParams; - SpecializationParams m_existentialSpecializationParams; - - /// Information about entry-point parameters - List<ShaderParamInfo> m_shaderParams; - - // The profile that the entry point will be compiled for - // (this is a combination of the target stage, and also - // a feature level that sets capabilities) - // - // Note: the profile-version part of this should probably - // be moving towards deprecation, in favor of the version - // information (e.g., "Shader Model 5.1") always coming - // from the target, while the stage part is all that is - // intrinsic to the entry point. - // - Profile m_profile; -}; - -class TypeConformance : public ComponentType, public slang::ITypeConformance -{ - typedef ComponentType Super; - -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - ISlangUnknown* getInterface(const Guid& guid); - - TypeConformance( - Linkage* linkage, - SubtypeWitness* witness, - Int confomrmanceIdOverride, - DiagnosticSink* sink); - - // Forward `IComponentType` methods - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } - - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointMetadata( - entryPointIndex, - targetIndex, - outMetadata, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCompileResult( - entryPointIndex, - targetIndex, - outCompileResult, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, - targetIndex, - outSharedLibrary, - outDiagnostics); - } - - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); - } - - virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; - - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; - List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE; - - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } - - /// Get the existential type parameter at `index`. - SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE - { - static SpecializationParam emptyParam; - return emptyParam; - } - - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return nullptr; - } - String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } - String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } - - Index getShaderParamCount() SLANG_OVERRIDE { return 0; } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return ShaderParamInfo(); - } - - SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } - IRModule* getIRModule() { return m_irModule.Ptr(); } - -protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - -private: - SubtypeWitness* m_subtypeWitness; - ModuleDependencyList m_moduleDependencyList; - FileDependencyList m_fileDependencyList; - List<RefPtr<Module>> m_requirements; - HashSet<Module*> m_requirementSet; - RefPtr<IRModule> m_irModule; - Int m_conformanceIdOverride; - void addDepedencyFromWitness(SubtypeWitness* witness); -}; - -enum class PassThroughMode : SlangPassThroughIntegral -{ - None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler - Fxc = SLANG_PASS_THROUGH_FXC, ///< pass through HLSL to `D3DCompile` API - Dxc = SLANG_PASS_THROUGH_DXC, ///< pass through HLSL to `IDxcCompiler` API - Glslang = SLANG_PASS_THROUGH_GLSLANG, ///< pass through GLSL to `glslang` library - SpirvDis = SLANG_PASS_THROUGH_SPIRV_DIS, ///< pass through spirv-dis - Clang = SLANG_PASS_THROUGH_CLANG, ///< Pass through clang compiler - VisualStudio = SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio compiler - Gcc = SLANG_PASS_THROUGH_GCC, ///< Gcc compiler - GenericCCpp = SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C/C++ compiler - NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler - LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt - MetalC = SLANG_PASS_THROUGH_METAL, - Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API - SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link - CountOf = SLANG_PASS_THROUGH_COUNT_OF, -}; -void printDiagnosticArg(StringBuilder& sb, PassThroughMode val); class SourceFile; -/// A module of code that has been compiled through the front-end -/// -/// A module comprises all the code from one translation unit (which -/// may span multiple Slang source files), and provides access -/// to both the AST and IR representations of that code. -/// -class Module : public ComponentType, public slang::IModule -{ - typedef ComponentType Super; - -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - ISlangUnknown* getInterface(const Guid& guid); - - - // Forward `IComponentType` methods - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } - - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, - targetIndex, - outSharedLibrary, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL - findEntryPointByName(char const* name, slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE - { - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } - SLANG_AST_BUILDER_RAII(m_astBuilder); - ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name))); - if ((!entryPoint)) - return SLANG_FAIL; - - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( - char const* name, - SlangStage stage, - slang::IEntryPoint** outEntryPoint, - ISlangBlob** outDiagnostics) override - { - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } - ComPtr<slang::IEntryPoint> entryPoint( - findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); - if ((!entryPoint)) - return SLANG_FAIL; - - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } - - virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override - { - return (SlangInt32)m_entryPoints.getCount(); - } - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override - { - if (index < 0 || index >= m_entryPoints.getCount()) - return SLANG_E_INVALID_ARG; - - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } - - ComPtr<slang::IEntryPoint> entryPoint(m_entryPoints[index].Ptr()); - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } - // - - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointMetadata( - entryPointIndex, - targetIndex, - outMetadata, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCompileResult( - entryPointIndex, - targetIndex, - outCompileResult, - outDiagnostics); - } - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( - SlangInt targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); - } - - /// Get a serialized representation of the checked module. - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - serialize(ISlangBlob** outSerializedBlob) override; - - /// Write the serialized representation of this module to a file. - virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override; - - /// Get the name of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override; - - /// Get the path of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override; - - /// Get the unique identity of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override; - - /// Get the number of dependency files that this module depends on. - /// This includes both the explicit source files, as well as any - /// additional files that were transitively referenced (e.g., via - /// a `#include` directive). - virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override; - - /// Get the path to a file this module depends on. - virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(SlangInt32 index) override; - - - // IModulePrecompileService_Experimental - /// Precompile TU to target language - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) override; - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) override; - - virtual SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; - - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( - SlangInt dependencyIndex, - slang::IModule** outModule, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; - - virtual SLANG_NO_THROW slang::DeclReflection* SLANG_MCALL getModuleReflection() SLANG_OVERRIDE; - - void setDigest(SHA1::Digest const& digest) { m_digest = digest; } - SHA1::Digest computeDigest(); - - /// Create a module (initially empty). - Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr); - - /// Get the AST for the module (if it has been parsed) - ModuleDecl* getModuleDecl() { return m_moduleDecl; } - - /// The the IR for the module (if it has been generated) - IRModule* getIRModule() { return m_irModule; } - - /// Get the list of other modules this module depends on - List<Module*> const& getModuleDependencyList() - { - return m_moduleDependencyList.getModuleList(); - } - - /// Get the list of files this module depends on - List<SourceFile*> const& getFileDependencyList() { return m_fileDependencyList.getFileList(); } - - /// Register a module that this module depends on - void addModuleDependency(Module* module); - - /// Register a source file that this module depends on - void addFileDependency(SourceFile* sourceFile); - - void clearFileDependency() { m_fileDependencyList.clear(); } - /// Set the AST for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setModuleDecl(ModuleDecl* moduleDecl); // { m_moduleDecl = moduleDecl; } - - void setName(String name); - void setName(Name* name) { m_name = name; } - Name* getNameObj() { return m_name; } - - void setPathInfo(PathInfo pathInfo) { m_pathInfo = pathInfo; } - - /// Set the IR for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setIRModule(IRModule* irModule) { m_irModule = irModule; } - - Index getEntryPointCount() SLANG_OVERRIDE { return 0; } - RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return nullptr; - } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return String(); - } - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return String(); - } - - Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } - - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE - { - return m_specializationParams.getCount(); - } - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE - { - return m_specializationParams[index]; - } - - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; - - List<Module*> const& getModuleDependencies() SLANG_OVERRIDE - { - return m_moduleDependencyList.getModuleList(); - } - List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE - { - return m_fileDependencyList.getFileList(); - } - - /// Given a mangled name finds the exported NodeBase associated with this module. - /// If not found returns nullptr. - Decl* findExportedDeclByMangledName(const UnownedStringSlice& mangledName); - - /// Ensure that the any accelerator(s) used for `findExportedDeclByMangledName` - /// have already been built. - /// - void ensureExportLookupAcceleratorBuilt(); - - Count getExportedDeclCount(); - Decl* getExportedDecl(Index index); - UnownedStringSlice getExportedDeclMangledName(Index index); - - /// Get the ASTBuilder - ASTBuilder* getASTBuilder() { return m_astBuilder; } - - /// Collect information on the shader parameters of the module. - /// - /// This method should only be called once, after the core - /// structured of the module (its AST and IR) have been created, - /// and before any of the `ComponentType` APIs are used. - /// - /// TODO: We might eventually consider a non-stateful approach - /// to constructing a `Module`. - /// - void _collectShaderParams(); - - void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets); - void _discoverEntryPointsImpl( - ContainerDecl* containerDecl, - DiagnosticSink* sink, - const List<RefPtr<TargetRequest>>& targets); - - - class ModuleSpecializationInfo : public SpecializationInfo - { - public: - struct GenericArgInfo - { - Decl* paramDecl = nullptr; - Val* argVal = nullptr; - }; - - List<GenericArgInfo> genericArgs; - List<ExpandedSpecializationArg> existentialArgs; - }; - - RefPtr<EntryPoint> findEntryPointByName(UnownedStringSlice const& name); - RefPtr<EntryPoint> findAndCheckEntryPoint( - UnownedStringSlice const& name, - SlangStage stage, - ISlangBlob** outDiagnostics); - - List<RefPtr<EntryPoint>>& getEntryPoints() { return m_entryPoints; } - void _addEntryPoint(EntryPoint* entryPoint); - void _processFindDeclsExportSymbolsRec(Decl* decl); - - // Gets the files that has been included into the module. - Dictionary<SourceFile*, FileDecl*>& getIncludedSourceFileMap() - { - return m_mapSourceFileToFileDecl; - } - -protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - -private: - Name* m_name = nullptr; - PathInfo m_pathInfo; - - // The AST for the module - ModuleDecl* m_moduleDecl = nullptr; - - // The IR for the module - RefPtr<IRModule> m_irModule = nullptr; - - List<ShaderParamInfo> m_shaderParams; - SpecializationParams m_specializationParams; - - List<Module*> m_requirements; - - // A digest that uniquely identifies the contents of the module. - SHA1::Digest m_digest; - - // List of modules this module depends on - ModuleDependencyList m_moduleDependencyList; - - // List of source files this module depends on - FileDependencyList m_fileDependencyList; - - // Entry points that were defined in this module - // - // Note: the entry point defined in the module are *not* - // part of the memory image/layout of the module when - // it is considered as an IComponentType. This can be - // a bit confusing, but if all the entry points in the - // module were automatically linked into the component - // type, we'd need a way to access just the global - // scope of the module without the entry points, in - // case we wanted to link a single entry point against - // the global scope. The `Module` type provides exactly - // that "module without its entry points" unit of - // granularity for linking. - // - // This list only exists for lookup purposes, so that - // the user can find an existing entry-point function - // that was defined as part of the module. - // - List<RefPtr<EntryPoint>> m_entryPoints; - - // The builder that owns all of the AST nodes from parsing the source of - // this module. - RefPtr<ASTBuilder> m_astBuilder; - - // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices, - // and m_mangledExportSymbols holds the NodeBase* values for each index. - StringSlicePool m_mangledExportPool; - List<Decl*> m_mangledExportSymbols; - - // Source files that have been pulled into the module with `__include`. - Dictionary<SourceFile*, FileDecl*> m_mapSourceFileToFileDecl; - -public: - SLANG_NO_THROW SlangResult SLANG_MCALL disassemble(slang::IBlob** outDisassembledBlob) override - { - if (!outDisassembledBlob) - return SLANG_E_INVALID_ARG; - String disassembly; - this->getIRModule()->getModuleInst()->dump(disassembly); - auto blob = StringUtil::createStringBlob(disassembly); - *outDisassembledBlob = blob.detach(); - return SLANG_OK; - } -}; -typedef Module LoadedModule; - -/// A request for the front-end to compile a translation unit. -class TranslationUnitRequest : public RefObject -{ -public: - TranslationUnitRequest(FrontEndCompileRequest* compileRequest); - TranslationUnitRequest(FrontEndCompileRequest* compileRequest, Module* m); - - // The parent compile request - FrontEndCompileRequest* compileRequest = nullptr; - - // The language in which the source file(s) - // are assumed to be written - SourceLanguage sourceLanguage = SourceLanguage::Unknown; - - /// Makes any source artifact available as a SourceFile. - /// If successful any of the source artifacts will be represented by the same index - /// of sourceArtifacts - SlangResult requireSourceFiles(); - - /// Get the source files. - /// Since lazily evaluated requires calling requireSourceFiles to know it's in sync - /// with sourceArtifacts. - List<SourceFile*> const& getSourceFiles(); - - /// Get the source artifacts associated - const List<ComPtr<IArtifact>>& getSourceArtifacts() const { return m_sourceArtifacts; } - - /// Clear all of the source - void clearSource() - { - m_sourceArtifacts.clear(); - m_sourceFiles.clear(); - m_includedFileSet.clear(); - } - - /// Add a source artifact - void addSourceArtifact(IArtifact* sourceArtifact); - - /// Add both the artifact and the sourceFile. - void addSource(IArtifact* sourceArtifact, SourceFile* sourceFile); - - void addIncludedSourceFileIfNotExist(SourceFile* sourceFile); - - // The entry points associated with this translation unit - List<RefPtr<EntryPoint>> const& getEntryPoints() { return module->getEntryPoints(); } - - void _addEntryPoint(EntryPoint* entryPoint) { module->_addEntryPoint(entryPoint); } - - // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `compileRequest` will be shared) - Dictionary<String, String> preprocessorDefinitions; - - /// The name that will be used for the module this translation unit produces. - Name* moduleName = nullptr; - - /// Result of compiling this translation unit (a module) - RefPtr<Module> module; - - bool isChecked = false; - - Module* getModule() { return module; } - ModuleDecl* getModuleDecl() { return module->getModuleDecl(); } - - Session* getSession(); - NamePool* getNamePool(); - SourceManager* getSourceManager(); - - Scope* getLanguageScope(); - - Dictionary<String, String> getCombinedPreprocessorDefinitions(); - - void setModuleName(Name* name) - { - moduleName = name; - if (module) - module->setName(name); - } - -protected: - void _addSourceFile(SourceFile* sourceFile); - /* Given an artifact, find a PathInfo. - If no PathInfo can be found will return an unknown PathInfo */ - PathInfo _findSourcePathInfo(IArtifact* artifact); - - List<ComPtr<IArtifact>> m_sourceArtifacts; - // The source file(s) that will be compiled to form this translation unit - // - // Usually, for HLSL or GLSL there will be only one file. - // NOTE! This member is generated lazily from m_sourceArtifacts - // it is *necessary* to call requireSourceFiles to ensure it's in sync. - List<SourceFile*> m_sourceFiles; - - // Track all the included source files added in m_sourceFiles - HashSet<SourceFile*> m_includedFileSet; -}; enum class FloatingPointMode : SlangFloatingPointModeIntegral { @@ -2044,105 +161,6 @@ enum class FloatingPointDenormalMode : SlangFpDenormalModeIntegral FlushToZero = SLANG_FP_DENORM_MODE_FTZ, }; -enum class WriterChannel : SlangWriterChannelIntegral -{ - Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, - StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, - StdError = SLANG_WRITER_CHANNEL_STD_ERROR, - CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, -}; - -enum class WriterMode : SlangWriterModeIntegral -{ - Text = SLANG_WRITER_MODE_TEXT, - Binary = SLANG_WRITER_MODE_BINARY, -}; - -class TargetRequest; - -/// Are we generating code for a D3D API? -bool isD3DTarget(TargetRequest* targetReq); - -// Are we generating code for Metal? -bool isMetalTarget(TargetRequest* targetReq); - -/// Are we generating code for a Khronos API (OpenGL or Vulkan)? -bool isKhronosTarget(TargetRequest* targetReq); -bool isKhronosTarget(CodeGenTarget target); - -/// Are we generating code for a CUDA API (CUDA / OptiX)? -bool isCUDATarget(TargetRequest* targetReq); - -// Are we generating code for a CPU target -bool isCPUTarget(TargetRequest* targetReq); - -/// Are we generating code for the WebGPU API? -bool isWGPUTarget(TargetRequest* targetReq); -bool isWGPUTarget(CodeGenTarget target); - -/// A request to generate output in some target format. -class TargetRequest : public RefObject -{ -public: - TargetRequest(Linkage* linkage, CodeGenTarget format); - - TargetRequest(const TargetRequest& other); - - Linkage* getLinkage() { return linkage; } - - Session* getSession(); - - CodeGenTarget getTarget() - { - return optionSet.getEnumOption<CodeGenTarget>(CompilerOptionName::Target); - } - - // TypeLayouts created on the fly by reflection API - struct TypeLayoutKey - { - Type* type; - slang::LayoutRules rules; - HashCode getHashCode() const - { - Hasher hasher; - hasher.hashValue(type); - hasher.hashValue(rules); - return hasher.getResult(); - } - bool operator==(TypeLayoutKey other) const - { - return type == other.type && rules == other.rules; - } - }; - Dictionary<TypeLayoutKey, RefPtr<TypeLayout>> typeLayouts; - - Dictionary<TypeLayoutKey, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; } - - TypeLayout* getTypeLayout(Type* type, slang::LayoutRules rules); - - CompilerOptionSet& getOptionSet() { return optionSet; } - - CapabilitySet getTargetCaps(); - - void setTargetCaps(CapabilitySet capSet); - - HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions(); - -private: - Linkage* linkage = nullptr; - CompilerOptionSet optionSet; - CapabilitySet cookedCapabilities; - RefPtr<HLSLToVulkanLayoutOptions> hlslToVulkanOptions; -}; - -/// Given a target request returns which (if any) intermediate source language is required -/// to produce it. -/// -/// If no intermediate source language is required, will return SourceLanguage::Unknown -SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* req); - -/// Are resource types "bindless" (implemented as ordinary data) on the given `target`? -bool areResourceTypesBindlessOnTarget(TargetRequest* target); // Compute the "effective" profile to use when outputting the given entry point // for the chosen code-generation target. @@ -2160,1721 +178,12 @@ bool areResourceTypesBindlessOnTarget(TargetRequest* target); Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); -/// Given a target returns the required downstream compiler -PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); -/// Given a target returns a downstream compiler the prelude should be taken from. -SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); - /// Get the build tag string const char* getBuildTagString(); -struct TypeCheckingCache; - -struct ContainerTypeKey -{ - slang::TypeReflection* elementType; - slang::ContainerType containerType; - bool operator==(ContainerTypeKey other) const - { - return elementType == other.elementType && containerType == other.containerType; - } - Slang::HashCode getHashCode() const - { - return Slang::combineHash( - Slang::getHashCode(elementType), - Slang::getHashCode(containerType)); - } -}; - -/// A dictionary of modules to be considered when resolving `import`s, -/// beyond those that would normally be found through a `Linkage`. -/// -/// Checking of an `import` declaration will bottleneck through -/// `Linkage::findOrImportModule`, which would usually just check for -/// any module that had been previously loaded into the same `Linkage` -/// (e.g., by a call to `Linkage::loadModule()`). -/// -/// In the case where compilation is being done through an -/// explicit `FrontEndCompileRequest` or `EndToEndCompileRequest`, -/// the modules being compiled by that request do not get added to -/// the surrounding `Linkage`. -/// -/// There is a corner case when an explicit compile request has -/// multiple `TranslationUnitRequest`s, because the user (reasonably) -/// expects that if they compile `A.slang` and `B.slang` as two -/// distinct translation units in the same compile request, then -/// an `import B` inside of `A.slang` should resolve to reference -/// the code of `B.slang`. But because neither `A` nor `B` gets -/// added to the `Linkage`, and the `Linkage` is what usually -/// determines what is or isn't loaded, that intuition will -/// be wrong, without a bit of help. -/// -/// The `LoadedModuleDictionary` is thus filled in by a -/// `FrontEndCompileRequest` to collect the modules it is compiling, -/// so that they can cross-reference one another (albeit with -/// a current implementation restriction that modules in the -/// request can only `import` those earlier in the request...). -/// -/// The dictionary then gets passed around between nearly all of -/// the operations that deal with loading modules, to make sure -/// that they can detect a previously loaded module. -/// -typedef Dictionary<Name*, Module*> LoadedModuleDictionary; - -enum ModuleBlobType -{ - Source, - IR -}; - -/// A context for loading and re-using code modules. -class Linkage : public RefObject, public slang::ISession -{ -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - CompilerOptionSet m_optionSet; - - ISlangUnknown* getInterface(const Guid& guid); - - SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL getGlobalSession() override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL - loadModule(const char* moduleName, slang::IBlob** outDiagnostics = nullptr) override; - slang::IModule* loadModuleFromBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - ModuleBlobType blobType, - slang::IBlob** outDiagnostics = nullptr); - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromIRBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW SlangResult SLANG_MCALL loadModuleInfoFromIRBlob( - slang::IBlob* source, - SlangInt& outModuleVersion, - const char*& outModuleCompilerVersion, - const char*& outModuleName) override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSource( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSourceString( - const char* moduleName, - const char* path, - const char* string, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createCompositeComponentType( - slang::IComponentType* const* componentTypes, - SlangInt componentTypeCount, - slang::IComponentType** outCompositeComponentType, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType( - slang::TypeReflection* type, - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getTypeLayout( - slang::TypeReflection* type, - SlangInt targetIndex = 0, - slang::LayoutRules rules = slang::LayoutRules::Default, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getContainerType( - slang::TypeReflection* elementType, - slang::ContainerType containerType, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getDynamicType() override; - SLANG_NO_THROW SlangResult SLANG_MCALL - getTypeRTTIMangledName(slang::TypeReflection* type, ISlangBlob** outNameBlob) override; - SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - ISlangBlob** outNameBlob) override; - SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessSequentialID( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - uint32_t* outId) override; - SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - uint32_t* outBytes, - uint32_t bufferSize) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - slang::ITypeConformance** outConformance, - SlangInt conformanceIdOverride, - ISlangBlob** outDiagnostics) override; - SLANG_NO_THROW SlangResult SLANG_MCALL - createCompileRequest(SlangCompileRequest** outCompileRequest) override; - virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() override; - virtual SLANG_NO_THROW slang::IModule* SLANG_MCALL getLoadedModule(SlangInt index) override; - virtual SLANG_NO_THROW bool SLANG_MCALL - isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) override; - - // Updates the supplied builder with linkage-related information, which includes preprocessor - // defines, the compiler version, and other compiler options. This is then merged with the hash - // produced for the program to produce a key that can be used with the shader cache. - void buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex = -1); - - void addTarget(slang::TargetDesc const& desc); - SlangResult addSearchPath(char const* path); - SlangResult addPreprocessorDefine(char const* name, char const* value); - SlangResult setMatrixLayoutMode(SlangMatrixLayoutMode mode); - /// Create an initially-empty linkage - Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); - - /// Dtor - ~Linkage(); - - bool isInLanguageServer() - { - return contentAssistInfo.checkingMode != ContentAssistCheckingMode::None; - } - - /// Get the parent session for this linkage - Session* getSessionImpl() { return m_session; } - - // Information on the targets we are being asked to - // generate code for. - List<RefPtr<TargetRequest>> targets; - - // Directories to search for `#include` files or `import`ed modules - SearchDirectoryList& getSearchDirectories(); - - // Source manager to help track files loaded - SourceManager m_defaultSourceManager; - SourceManager* m_sourceManager = nullptr; - RefPtr<CommandLineContext> m_cmdLineContext; - - // Used to store strings returned by the api as const char* - StringSlicePool m_stringSlicePool; - - // Name pool for looking up names - NamePool* namePool = nullptr; - - NamePool* getNamePool() { return namePool; } - - ASTBuilder* getASTBuilder() { return m_astBuilder; } - - RefPtr<ASTBuilder> m_astBuilder; - - // Cache for container types. - Dictionary<ContainerTypeKey, Type*> m_containerTypes; - - // cache used by type checking, implemented in check.cpp - TypeCheckingCache* getTypeCheckingCache(); - void destroyTypeCheckingCache(); - - RefPtr<RefObject> m_typeCheckingCache = nullptr; - - // Modules that have been dynamically loaded via `import` - // - // This is a list of unique modules loaded, in the order they were encountered. - List<RefPtr<LoadedModule>> loadedModulesList; - - // Map from the path (or uniqueIdentity if available) of a module file to its definition - Dictionary<String, RefPtr<LoadedModule>> mapPathToLoadedModule; - - // Map from the logical name of a module to its definition - Dictionary<Name*, RefPtr<LoadedModule>> mapNameToLoadedModules; - - // Map from the mangled name of RTTI objects to sequential IDs - // used by `switch`-based dynamic dispatch. - Dictionary<String, uint32_t> mapMangledNameToRTTIObjectIndex; - - // Counters for allocating sequential IDs to witness tables conforming to each interface type. - Dictionary<String, uint32_t> mapInterfaceMangledNameToSequentialIDCounters; - - SearchDirectoryList searchDirectoryCache; - - // The resulting specialized IR module for each entry point request - List<RefPtr<IRModule>> compiledModules; - - ContentAssistInfo contentAssistInfo; - - /// File system implementation to use when loading files from disk. - /// - /// If this member is `null`, a default implementation that tries - /// to use the native OS filesystem will be used instead. - /// - ComPtr<ISlangFileSystem> m_fileSystem; - - /// The extended file system implementation. Will be set to a default implementation - /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, - /// or a wrapped impl that makes fileSystem operate as fileSystemExt - ComPtr<ISlangFileSystemExt> m_fileSystemExt; - - /// Get the currenly set file system - ISlangFileSystemExt* getFileSystemExt() { return m_fileSystemExt; } - - /// Load a file into memory using the configured file system. - /// - /// @param path The path to attempt to load from - /// @param outBlob A destination pointer to receive the loaded blob - /// @returns A `SlangResult` to indicate success or failure. - /// - SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob); - - Expr* parseTermString(String str, Scope* scope); - - Type* specializeType( - Type* unspecializedType, - Int argCount, - Type* const* args, - DiagnosticSink* sink); - - /// Add a new target and return its index. - UInt addTarget(CodeGenTarget target); - - /// "Bottleneck" routine for loading a module. - /// - /// All attempts to load a module, whether through - /// Slang API calls, `import` operations, or other - /// means, should bottleneck through `loadModuleImpl`, - /// or one of the specialized cases `loadSourceModuleImpl` - /// and `loadBinaryModuleImpl`. - /// - RefPtr<Module> loadModuleImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules, - ModuleBlobType blobType); - - RefPtr<Module> loadSourceModuleImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules); - - RefPtr<Module> loadBinaryModuleImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink); - - /// Either finds a previously-loaded module matching what - /// was serialized into `moduleChunk`, or else attempts - /// to load the serialized module. - /// - /// If a previously-loaded module is found that matches the - /// name or path information in `moduleChunk`, then that - /// previously-loaded module is returned. - /// - /// Othwerise, attempts to load a module from `moduleChunk` - /// and, if successful, returns the freshly loaded module. - /// - /// Otherwise, return null. - /// - RefPtr<Module> findOrLoadSerializedModuleForModuleLibrary( - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* libraryChunk, - DiagnosticSink* sink); - - RefPtr<Module> loadSerializedModule( - Name* moduleName, - const PathInfo& moduleFilePathInfo, - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* containerChunk, //< The outer container, if there is one. - SourceLoc const& requestingLoc, - DiagnosticSink* sink); - - SlangResult loadSerializedModuleContents( - Module* module, - const PathInfo& moduleFilePathInfo, - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* containerChunk, //< The outer container, if there is one. - DiagnosticSink* sink); - - SourceFile* loadSourceFile(String pathFrom, String path); - - void loadParsedModule( - RefPtr<FrontEndCompileRequest> compileRequest, - RefPtr<TranslationUnitRequest> translationUnit, - Name* name, - PathInfo const& pathInfo); - - bool isBinaryModuleUpToDate(String fromPath, RIFF::ListChunk const* baseChunk); - - RefPtr<Module> findOrImportModule( - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* loadedModules = nullptr); - - SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); - struct IncludeResult - { - FileDecl* fileDecl; - bool isNew; - }; - IncludeResult findAndIncludeFile( - Module* module, - TranslationUnitRequest* translationUnit, - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink); - - SourceManager* getSourceManager() { return m_sourceManager; } - - /// Override the source manager for the linkage. - /// - /// This is only used to install a temporary override when - /// parsing stuff from strings (where we don't want to retain - /// full source files for the parsed result). - /// - /// TODO: We should remove the need for this hack. - /// - void setSourceManager(SourceManager* sourceManager) { m_sourceManager = sourceManager; } - - void setRequireCacheFileSystem(bool requireCacheFileSystem); - - void setFileSystem(ISlangFileSystem* fileSystem); - - DeclRef<Decl> specializeGeneric( - DeclRef<Decl> declRef, - List<Expr*> argExprs, - DiagnosticSink* sink); - - DeclRef<Decl> specializeWithArgTypes( - Expr* funcExpr, - List<Type*> argTypes, - DiagnosticSink* sink); - - bool isSpecialized(DeclRef<Decl> declRef); - - DiagnosticSink::Flags diagnosticSinkFlags = 0; - - bool m_requireCacheFileSystem = false; - - // Modules that have been read in with the -r option - List<ComPtr<IArtifact>> m_libModules; - - void _stopRetainingParentSession() { m_retainedSession = nullptr; } - - // Get shared semantics information for reflection purposes. - SharedSemanticsContext* getSemanticsForReflection(); - -private: - /// The global Slang library session that this linkage is a child of - Session* m_session = nullptr; - - RefPtr<Session> m_retainedSession; - - /// Tracks state of modules currently being loaded. - /// - /// This information is used to diagnose cases where - /// a user tries to recursively import the same module - /// (possibly along a transitive chain of `import`s). - /// - struct ModuleBeingImportedRAII - { - public: - ModuleBeingImportedRAII( - Linkage* linkage, - Module* module, - Name* name, - SourceLoc const& importLoc) - : linkage(linkage), module(module), name(name), importLoc(importLoc) - { - next = linkage->m_modulesBeingImported; - linkage->m_modulesBeingImported = this; - } - - ~ModuleBeingImportedRAII() { linkage->m_modulesBeingImported = next; } - - Linkage* linkage; - Module* module; - Name* name; - SourceLoc importLoc; - ModuleBeingImportedRAII* next; - }; - - // Any modules currently being imported will be listed here - ModuleBeingImportedRAII* m_modulesBeingImported = nullptr; - - /// Is the given module in the middle of being imported? - bool isBeingImported(Module* module); - - /// Diagnose that an error occured in the process of importing a module - void _diagnoseErrorInImportedModule(DiagnosticSink* sink); - - List<Type*> m_specializedTypes; - - RefPtr<SharedSemanticsContext> m_semanticsForReflection; -}; - -/// Shared functionality between front- and back-end compile requests. -/// -/// This is the base class for both `FrontEndCompileRequest` and -/// `BackEndCompileRequest`, and allows a small number of parts of -/// the compiler to be easily invocable from either front-end or -/// back-end work. -/// -class CompileRequestBase : public RefObject -{ - // TODO: We really shouldn't need this type in the long run. - // The few places that rely on it should be refactored to just - // depend on the underlying information (a linkage and a diagnostic - // sink) directly. - // - // The flags to control dumping and validation of IR should be - // moved to some kind of shared settings/options `struct` that - // both front-end and back-end requests can store. - -public: - Session* getSession(); - Linkage* getLinkage() { return m_linkage; } - DiagnosticSink* getSink() { return m_sink; } - SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } - SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) - { - return getLinkage()->loadFile(path, outPathInfo, outBlob); - } - -protected: - CompileRequestBase(Linkage* linkage, DiagnosticSink* sink); - -private: - Linkage* m_linkage = nullptr; - DiagnosticSink* m_sink = nullptr; -}; - -/// A request to compile source code to an AST + IR. -class FrontEndCompileRequest : public CompileRequestBase -{ -public: - /// Note that writers can be parsed as nullptr to disable output, - /// and individual channels set to null to disable them - FrontEndCompileRequest(Linkage* linkage, StdWriters* writers, DiagnosticSink* sink); - - int addEntryPoint(int translationUnitIndex, String const& name, Profile entryPointProfile); - - // Translation units we are being asked to compile - List<RefPtr<TranslationUnitRequest>> translationUnits; - - // Additional modules that needs to be made visible to `import` while checking. - const LoadedModuleDictionary* additionalLoadedModules = nullptr; - - RefPtr<TranslationUnitRequest> getTranslationUnit(UInt index) - { - return translationUnits[index]; - } - - // If true will serialize and de-serialize with debug information - bool verifyDebugSerialization = false; - - CompilerOptionSet optionSet; - - List<RefPtr<FrontEndEntryPointRequest>> m_entryPointReqs; - - List<RefPtr<FrontEndEntryPointRequest>> const& getEntryPointReqs() { return m_entryPointReqs; } - UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } - FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } - - void parseTranslationUnit(TranslationUnitRequest* translationUnit); - - // Perform primary semantic checking on all - // of the translation units in the program - void checkAllTranslationUnits(); - - void checkEntryPoints(); - - void generateIR(); - - SlangResult executeActionsInner(); - - /// Add a translation unit to be compiled. - /// - /// @param language The source language that the translation unit will use (e.g., - /// `SourceLanguage::Slang` - /// @param moduleName The name that will be used for the module compile from the translation - /// unit. - /// - /// If moduleName is passed as nullptr a module name is generated. - /// If all translation units in a compile request use automatically generated - /// module names, then they are guaranteed not to conflict with one another. - /// - /// @return The zero-based index of the translation unit in this compile request. - int addTranslationUnit(SourceLanguage language, Name* moduleName); - - int addTranslationUnit(TranslationUnitRequest* translationUnit); - - void addTranslationUnitSourceArtifact(int translationUnitIndex, IArtifact* sourceArtifact); - - void addTranslationUnitSourceBlob( - int translationUnitIndex, - String const& path, - ISlangBlob* sourceBlob); - - void addTranslationUnitSourceFile(int translationUnitIndex, String const& path); - - /// Get a component type that represents the global scope of the compile request. - ComponentType* getGlobalComponentType() { return m_globalComponentType; } - - /// Get a component type that represents the global scope of the compile request, plus the - /// requested entry points. - ComponentType* getGlobalAndEntryPointsComponentType() - { - return m_globalAndEntryPointsComponentType; - } - - List<RefPtr<ComponentType>> const& getUnspecializedEntryPoints() - { - return m_unspecializedEntryPoints; - } - - /// Does the code we are compiling represent part of the Slang core module? - bool m_isCoreModuleCode = false; - - Name* m_defaultModuleName = nullptr; - - /// The irDumpOptions - IRDumpOptions m_irDumpOptions; - - /// An "extra" entry point that was added via a library reference - struct ExtraEntryPointInfo - { - Name* name; - Profile profile; - String mangledName; - }; - - /// A list of "extra" entry points added via a library reference - List<ExtraEntryPointInfo> m_extraEntryPoints; - -private: - /// A component type that includes only the global scopes of the translation unit(s) that were - /// compiled. - RefPtr<ComponentType> m_globalComponentType; - - /// A component type that extends the global scopes with all of the entry points that were - /// specified. - RefPtr<ComponentType> m_globalAndEntryPointsComponentType; - - List<RefPtr<ComponentType>> m_unspecializedEntryPoints; - - RefPtr<StdWriters> m_writers; -}; - -/// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. -class ComponentTypeVisitor -{ -public: - // The following methods should be overriden in a concrete subclass - // to customize how it acts on each of the concrete types of component. - // - // In cases where the application wants to simply "recurse" on a - // composite, specialized, or legacy component type it can use - // the `visitChildren` methods below. - // - virtual void visitEntryPoint( - EntryPoint* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; - virtual void visitModule( - Module* module, - Module::ModuleSpecializationInfo* specializationInfo) = 0; - virtual void visitComposite( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; - virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; - virtual void visitTypeConformance(TypeConformance* conformance) = 0; - virtual void visitRenamedEntryPoint( - RenamedEntryPointComponentType* renamedEntryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; - -protected: - // These helpers can be used to recurse into the logical children of a - // component type, and are useful for the common case where a visitor - // only cares about a few leaf cases. - // - void visitChildren( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo); - void visitChildren(SpecializedComponentType* specialized); -}; - -/// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` -/// -/// TODO: This should probably be renamed to `TargetComponentType`. -/// -/// By binding a component type to a specific target, a `TargetProgram` allows -/// for things like layout to be computed, that fundamentally depend on -/// the choice of target. -/// -/// A `TargetProgram` handles request for compiled kernel code for -/// entry point functions. In practice, kernel code can only be -/// correctly generated when the underlying `ComponentType` is "fully linked" -/// (has no remaining unsatisfied requirements). -/// -class TargetProgram : public RefObject -{ -public: - TargetProgram(ComponentType* componentType, TargetRequest* targetReq); - - /// Get the underlying program - ComponentType* getProgram() { return m_program; } - - /// Get the underlying target - TargetRequest* getTargetReq() { return m_targetReq; } - - /// Get the layout for the program on the target. - /// - /// If this is the first time the layout has been - /// requested, report any errors that arise during - /// layout to the given `sink`. - /// - ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); - - /// Get the layout for the program on the target. - /// - /// This routine assumes that `getOrCreateLayout` - /// has already been called previously. - /// - ProgramLayout* getExistingLayout() - { - SLANG_ASSERT(m_layout); - return m_layout; - } - - /// Get the compiled code for an entry point on the target. - /// - /// If this is the first time that code generation has - /// been requested, report any errors that arise during - /// code generation to the given `sink`. - /// - IArtifact* getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink); - IArtifact* getOrCreateWholeProgramResult(DiagnosticSink* sink); - - IArtifact* getExistingWholeProgramResult() { return m_wholeProgramResult; } - /// Get the compiled code for an entry point on the target. - /// - /// This routine assumes that `getOrCreateEntryPointResult` - /// has already been called previously. - /// - IArtifact* getExistingEntryPointResult(Int entryPointIndex) - { - return m_entryPointResults[entryPointIndex]; - } - - IArtifact* _createWholeProgramResult( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq = nullptr); - - /// Internal helper for `getOrCreateEntryPointResult`. - /// - /// This is used so that command-line and API-based - /// requests for code can bottleneck through the same place. - /// - /// Shouldn't be called directly by most code. - /// - IArtifact* _createEntryPointResult( - Int entryPointIndex, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq = nullptr); - - RefPtr<IRModule> getOrCreateIRModuleForLayout(DiagnosticSink* sink); - - RefPtr<IRModule> getExistingIRModuleForLayout() { return m_irModuleForLayout; } - - CompilerOptionSet& getOptionSet() { return m_optionSet; } - - HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions() - { - return m_targetReq->getHLSLToVulkanLayoutOptions(); - } - - bool shouldEmitSPIRVDirectly() - { - return isKhronosTarget(m_targetReq) && getOptionSet().shouldEmitSPIRVDirectly(); - } - -private: - RefPtr<IRModule> createIRModuleForLayout(DiagnosticSink* sink); - - // The program being compiled or laid out - ComponentType* m_program; - - // The target that code/layout will be generated for - TargetRequest* m_targetReq; - - // The computed layout, if it has been generated yet - RefPtr<ProgramLayout> m_layout; - - CompilerOptionSet m_optionSet; - - // Generated compile results for each entry point - // in the parent `Program` (indexing matches - // the order they are given in the `Program`) - ComPtr<IArtifact> m_wholeProgramResult; - List<ComPtr<IArtifact>> m_entryPointResults; - - RefPtr<IRModule> m_irModuleForLayout; -}; - -/// A back-end-specific object to track optional feaures/capabilities/extensions -/// that are discovered to be used by a program/kernel as part of code generation. -class ExtensionTracker : public RefObject -{ - // TODO: The existence of this type is evidence of a design/architecture problem. - // - // A better formulation of things requires a few key changes: - // - // 1. All optional capabilities need to be enumerated as part of the `CapabilitySet` - // system, so that they can be reasoned about uniformly across different targets - // and different layers of the compiler. - // - // 2. The front-end should be responsible for either or both of: - // - // * Checking that `public` or otherwise externally-visible items (declarations/definitions) - // explicitly declare the capabilities they require, and that they only ever - // make use of items that are comatible with those required capabilities. - // - // * Inferring the capabilities required by items that are not externally visible, - // and attaching those capabilities explicit as a modifier or other synthesized AST node. - // - // 3. The capabilities required by a given `ComponentType` and its entry points should be - // explicitly know-able, and they should be something we can compare to the capabilities - // of a code generation target *before* back-end code generation is started. We should be - // able to issue error messages around lacking capabilities in a way the user can understand, - // in terms of the high-level-language entities. - -public: -}; - -struct RequiredLoweringPassSet -{ - bool debugInfo; - bool resultType; - bool optionalType; - bool enumType; - bool combinedTextureSamplers; - bool reinterpret; - bool generics; - bool bindExistential; - bool autodiff; - bool derivativePyBindWrapper; - bool bitcast; - bool existentialTypeLayout; - bool bindingQuery; - bool meshOutput; - bool higherOrderFunc; - bool globalVaryingVar; - bool glslSSBO; - bool byteAddressBuffer; - bool dynamicResource; - bool dynamicResourceHeap; - bool resolveVaryingInputRef; - bool specializeStageSwitch; - bool missingReturn; - bool nonVectorCompositeSelect; -}; - -/// A context for code generation in the compiler back-end -struct CodeGenContext -{ -public: - typedef List<Index> EntryPointIndices; - - struct Shared - { - public: - Shared( - TargetProgram* targetProgram, - EntryPointIndices const& entryPointIndices, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) - : targetProgram(targetProgram) - , entryPointIndices(entryPointIndices) - , sink(sink) - , endToEndReq(endToEndReq) - { - } - - // Shared( - // TargetProgram* targetProgram, - // EndToEndCompileRequest* endToEndReq); - - TargetProgram* targetProgram = nullptr; - EntryPointIndices entryPointIndices; - DiagnosticSink* sink = nullptr; - EndToEndCompileRequest* endToEndReq = nullptr; - }; - - CodeGenContext(Shared* shared) - : m_shared(shared) - , m_targetFormat(shared->targetProgram->getTargetReq()->getTarget()) - , m_targetProfile(shared->targetProgram->getOptionSet().getProfile()) - { - } - - CodeGenContext( - CodeGenContext* base, - CodeGenTarget targetFormat, - ExtensionTracker* extensionTracker = nullptr) - : m_shared(base->m_shared) - , m_targetFormat(targetFormat) - , m_extensionTracker(extensionTracker) - { - } - - /// Get the diagnostic sink - DiagnosticSink* getSink() { return m_shared->sink; } - - TargetProgram* getTargetProgram() { return m_shared->targetProgram; } - - EntryPointIndices const& getEntryPointIndices() { return m_shared->entryPointIndices; } - - CodeGenTarget getTargetFormat() { return m_targetFormat; } - - ExtensionTracker* getExtensionTracker() { return m_extensionTracker; } - - TargetRequest* getTargetReq() { return getTargetProgram()->getTargetReq(); } - - CapabilitySet getTargetCaps() { return getTargetReq()->getTargetCaps(); } - - CodeGenTarget getFinalTargetFormat() { return getTargetReq()->getTarget(); } - - ComponentType* getProgram() { return getTargetProgram()->getProgram(); } - - Linkage* getLinkage() { return getProgram()->getLinkage(); } - - Session* getSession() { return getLinkage()->getSessionImpl(); } - - /// Get the source manager - SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - - ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } - - EndToEndCompileRequest* isEndToEndCompile() { return m_shared->endToEndReq; } - - EndToEndCompileRequest* isPassThroughEnabled(); - - Count getEntryPointCount() { return getEntryPointIndices().getCount(); } - - EntryPoint* getEntryPoint(Index index) { return getProgram()->getEntryPoint(index); } - - Index getSingleEntryPointIndex() - { - SLANG_ASSERT(getEntryPointCount() == 1); - return getEntryPointIndices()[0]; - } - - // - - IRDumpOptions getIRDumpOptions(); - - bool shouldValidateIR(); - bool shouldDumpIR(); - bool shouldReportCheckpointIntermediates(); - - bool shouldTrackLiveness(); - - bool shouldDumpIntermediates(); - String getIntermediateDumpPrefix(); - - bool getUseUnknownImageFormatAsDefault(); - - bool isSpecializationDisabled(); - - bool shouldSkipSPIRVValidation(); - - SlangResult requireTranslationUnitSourceFiles(); - - // - - SlangResult emitEntryPoints(ComPtr<IArtifact>& outArtifact); - - SlangResult emitPrecompiledDownstreamIR(ComPtr<IArtifact>& outArtifact); - - void maybeDumpIntermediate(IArtifact* artifact); - - // Used to cause instructions available in precompiled blobs to be - // removed between IR linking and target source generation. - bool removeAvailableInDownstreamIR = false; - - // Determines if program level compilation like getTargetCode() or getEntryPointCode() - // should return a fully linked downstream program or just the glue SPIR-V/DXIL that - // imports and uses the precompiled SPIR-V/DXIL from constituent modules. - // This is a no-op if modules are not precompiled. - bool shouldSkipDownstreamLinking(); - - RequiredLoweringPassSet& getRequiredLoweringPassSet() { return m_requiredLoweringPassSet; } - -protected: - CodeGenTarget m_targetFormat = CodeGenTarget::Unknown; - Profile m_targetProfile; - ExtensionTracker* m_extensionTracker = nullptr; - - // To improve the performance of our backend, we will try to avoid running - // passes related to features not used in the user code. - // To do so, we will scan the IR module once, and determine which passes are needed - // based on the instructions used in the IR module. - // This will allow us to skip running passes that are not needed, without having to - // run all the passes only to find out that no work is needed. - // This is especially important for the performance of the backend, as some passes - // have an initialization cost (such as building reference graphs or DOM trees) that - // can be expensive. - RequiredLoweringPassSet m_requiredLoweringPassSet; - - /// Will output assembly as well as the artifact if appropriate for the artifact type for - /// assembly output and conversion is possible - void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact); - - void _dumpIntermediate(IArtifact* artifact); - void _dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size); - - /* Emits entry point source taking into account if a pass-through or not. Uses 'targetFormat' to - determine the target (not targetReq) */ - SlangResult emitEntryPointsSource(ComPtr<IArtifact>& outArtifact); - - SlangResult emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outArtifact); - - SlangResult emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& outArtifact); - - /* Determines a suitable filename to identify the input for a given entry point being compiled. - If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file - pathname for the translation unit containing the entry point at `entryPointIndex. - If the compilation is not in a pass-through case, then always returns `"slang-generated"`. - @param endToEndReq The end-to-end compile request which might be using pass-through compilation - @param entryPointIndex The index of the entry point to compute a filename for. - @return the appropriate source filename */ - String calcSourcePathForEntryPoints(); - - TranslationUnitRequest* findPassThroughTranslationUnit(Int entryPointIndex); - - - SlangResult _emitEntryPoints(ComPtr<IArtifact>& outArtifact); - -private: - Shared* m_shared = nullptr; -}; - -/// A compile request that spans the front and back ends of the compiler -/// -/// This is what the command-line `slangc` uses, as well as the legacy -/// C API. It ties together the functionality of `Linkage`, -/// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small -/// number of additional features that primarily make sense for -/// command-line usage. -/// -class EndToEndCompileRequest : public RefObject, public slang::ICompileRequest -{ -public: - SLANG_CLASS_GUID(0xce6d2383, 0xee1b, 0x4fd7, {0xa0, 0xf, 0xb8, 0xb6, 0x33, 0x12, 0x95, 0xc8}) - - // ISlangUnknown - SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) - SLANG_OVERRIDE; - SLANG_REF_OBJECT_IUNKNOWN_ADD_REF - SLANG_REF_OBJECT_IUNKNOWN_RELEASE - - // slang::ICompileRequest - virtual SLANG_NO_THROW void SLANG_MCALL setFileSystem(ISlangFileSystem* fileSystem) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCompileFlags(SlangCompileFlags flags) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangCompileFlags SLANG_MCALL getCompileFlags() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediates(int enable) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediatePrefix(const char* prefix) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setEnableEffectAnnotations(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setLineDirectiveMode(SlangLineDirectiveMode mode) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCodeGenTarget(SlangCompileTarget target) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addCodeGenTarget(SlangCompileTarget target) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetProfile(int targetIndex, SlangProfileID profile) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetFlags(int targetIndex, SlangTargetFlags flags) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetFloatingPointMode(int targetIndex, SlangFloatingPointMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetForceDXLayout(int targetIndex, bool value) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetGenerateWholeProgram(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetEmbedDownstreamIR(int targetIndex, bool value) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setMatrixLayoutMode(SlangMatrixLayoutMode mode) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoLevel(SlangDebugInfoLevel level) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setOptimizationLevel(SlangOptimizationLevel level) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setOutputContainerFormat(SlangContainerFormat format) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setPassThrough(SlangPassThrough passThrough) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setDiagnosticCallback(SlangDiagnosticCallback callback, void const* userData) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setWriter(SlangWriterChannel channel, ISlangWriter* writer) SLANG_OVERRIDE; - virtual SLANG_NO_THROW ISlangWriter* SLANG_MCALL getWriter(SlangWriterChannel channel) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addSearchPath(const char* searchDir) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - addPreprocessorDefine(const char* key, const char* value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - processCommandLineArguments(char const* const* args, int argCount) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL - addTranslationUnit(SlangSourceLanguage language, char const* name) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDefaultModuleName(const char* defaultModuleName) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitPreprocessorDefine( - int translationUnitIndex, - const char* key, - const char* value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - addTranslationUnitSourceFile(int translationUnitIndex, char const* path) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceString( - int translationUnitIndex, - char const* path, - char const* source) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL addLibraryReference( - const char* basePath, - const void* libData, - size_t libDataSize) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceStringSpan( - int translationUnitIndex, - char const* path, - char const* sourceBegin, - char const* sourceEnd) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceBlob( - int translationUnitIndex, - char const* path, - ISlangBlob* sourceBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL - addEntryPoint(int translationUnitIndex, char const* name, SlangStage stage) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addEntryPointEx( - int translationUnitIndex, - char const* name, - SlangStage stage, - int genericArgCount, - char const** genericArgs) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - setGlobalGenericArgs(int genericArgCount, char const** genericArgs) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - setTypeNameForGlobalExistentialTypeParam(int slotIndex, char const* typeName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL setTypeNameForEntryPointExistentialTypeParam( - int entryPointIndex, - int slotIndex, - char const* typeName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setAllowGLSLInput(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL compile() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getDiagnosticOutput() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getDiagnosticOutputBlob(ISlangBlob** outBlob) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL getDependencyFileCount() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(int index) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL getTranslationUnitCount() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getEntryPointSource(int entryPointIndex) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void const* SLANG_MCALL - getEntryPointCode(int entryPointIndex, size_t* outSize) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCodeBlob( - int entryPointIndex, - int targetIndex, - ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getTargetHostCallable(int targetIndex, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void const* SLANG_MCALL getCompileRequestCode(size_t* outSize) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW ISlangMutableFileSystem* SLANG_MCALL - getCompileRequestResultAsFileSystem() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getContainerCode(ISlangBlob** outBlob) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - loadRepro(ISlangFileSystem* fileSystem, const void* data, size_t size) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL saveRepro(ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL enableReproCapture() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getProgram(slang::IComponentType** outProgram) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getEntryPoint(SlangInt entryPointIndex, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getModule(SlangInt translationUnitIndex, slang::IModule** outModule) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getSession(slang::ISession** outSession) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangReflection* SLANG_MCALL getReflection() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCommandLineCompilerMode() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - addTargetCapability(SlangInt targetIndex, SlangCapabilityID capability) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getProgramWithEntryPoints(slang::IComponentType** outProgram) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL isParameterLocationUsed( - SlangInt entryPointIndex, - SlangInt targetIndex, - SlangParameterCategory category, - SlangUInt spaceIndex, - SlangUInt registerIndex, - bool& outUsed) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetLineDirectiveMode(SlangInt targetIndex, SlangLineDirectiveMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - overrideDiagnosticSeverity(SlangInt messageID, SlangSeverity overrideSeverity) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangDiagnosticFlags SLANG_MCALL getDiagnosticFlags() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDiagnosticFlags(SlangDiagnosticFlags flags) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) - SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setSkipSPIRVValidation(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL - setTargetUseMinimumSlangOptimization(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setIgnoreCapabilityCheck(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL - getCompileTimeProfile(ISlangProfiler** compileTimeProfile, bool isClear) SLANG_OVERRIDE; - - void setTrackLiveness(bool v); - - EndToEndCompileRequest(Session* session); - - EndToEndCompileRequest(Linkage* linkage); - - ~EndToEndCompileRequest(); - - // If enabled will emit IR - bool m_emitIr = false; - - // What container format are we being asked to generate? - // If it's set to a format, the container blob will be calculated during compile - ContainerFormat m_containerFormat = ContainerFormat::None; - - /// Where the container is stored. This is calculated as part of compile if m_containerFormat is - /// set to a supported format. - ComPtr<IArtifact> m_containerArtifact; - /// Holds the container as a file system - ComPtr<ISlangMutableFileSystem> m_containerFileSystem; - - /// File system used by repro system if a file couldn't be found within the repro (or associated - /// directory) - ComPtr<ISlangFileSystem> m_reproFallbackFileSystem = - ComPtr<ISlangFileSystem>(OSFileSystem::getExtSingleton()); - - // Path to output container to - String m_containerOutputPath; - - // Should we just pass the input to another compiler? - PassThroughMode m_passThrough = PassThroughMode::None; - - /// If output should be source embedded, define the style of the embedding - SourceEmbedUtil::Style m_sourceEmbedStyle = SourceEmbedUtil::Style::None; - /// The language to be used for source embedding - SourceLanguage m_sourceEmbedLanguage = SourceLanguage::C; - /// Source embed variable name. Note may be used as a basis for names if multiple items written - String m_sourceEmbedName; - - /// Source code for the specialization arguments to use for the global specialization parameters - /// of the program. - List<String> m_globalSpecializationArgStrings; - - // Are we being driven by the command-line `slangc`, and should act accordingly? - bool m_isCommandLineCompile = false; - - String m_diagnosticOutput; - - /// A blob holding the diagnostic output - ComPtr<ISlangBlob> m_diagnosticOutputBlob; - - /// Per-entry-point information not tracked by other compile requests - class EntryPointInfo : public RefObject - { - public: - /// Source code for the specialization arguments to use for the specialization parameters of - /// the entry point. - List<String> specializationArgStrings; - }; - List<EntryPointInfo> m_entryPoints; - - /// Per-target information only needed for command-line compiles - class TargetInfo : public RefObject - { - public: - // Requested output paths for each entry point. - // An empty string indices no output desired for - // the given entry point. - Dictionary<Int, String> entryPointOutputPaths; - String wholeTargetOutputPath; - CompilerOptionSet targetOptions; - }; - Dictionary<TargetRequest*, RefPtr<TargetInfo>> m_targetInfos; - - CompilerOptionSet m_optionSetForDefaultTarget; - - CompilerOptionSet& getTargetOptionSet(TargetRequest* req); - - CompilerOptionSet& getTargetOptionSet(Index targetIndex); - - String m_dependencyOutputPath; - - /// Writes the modules in a container to the stream - SlangResult writeContainerToStream(Stream* stream); - - /// If a container format has been specified produce a container (stored in m_containerBlob) - SlangResult maybeCreateContainer(); - /// If a container has been constructed and the filename/path has contents will try to write - /// the container contents to the file - SlangResult maybeWriteContainer(const String& fileName); - - Linkage* getLinkage() { return m_linkage; } - - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile profile, - List<String> const& genericTypeNames); - - void setWriter(WriterChannel chan, ISlangWriter* writer); - ISlangWriter* getWriter(WriterChannel chan) const - { - return m_writers->getWriter(SlangWriterChannel(chan)); - } - - /// The end to end request can be passed as nullptr, if not driven by one - SlangResult executeActionsInner(); - SlangResult executeActions(); - - Session* getSession() { return m_session; } - DiagnosticSink* getSink() { return &m_sink; } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - - FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } - - ComponentType* getUnspecializedGlobalComponentType() - { - return getFrontEndReq()->getGlobalComponentType(); - } - ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() - { - return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); - } - - ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } - ComponentType* getSpecializedGlobalAndEntryPointsComponentType() - { - return m_specializedGlobalAndEntryPointsComponentType; - } - - ComponentType* getSpecializedEntryPointComponentType(Index index) - { - return m_specializedEntryPoints[index]; - } - - void writeArtifactToStandardOutput(IArtifact* artifact, DiagnosticSink* sink); - - void generateOutput(); - - CompilerOptionSet& getOptionSet() { return m_linkage->m_optionSet; } - -private: - String _getWholeProgramPath(TargetRequest* targetReq); - String _getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex); - - /// Maybe write the artifact to the path (if set), or stdout (if there is no container or path) - SlangResult _maybeWriteArtifact(const String& path, IArtifact* artifact); - SlangResult _maybeWriteDebugArtifact( - TargetProgram* targetProgram, - const String& path, - IArtifact* artifact); - SlangResult _writeArtifact(const String& path, IArtifact* artifact); - - /// Adds any extra settings to complete a targetRequest - void _completeTargetRequest(UInt targetIndex); - - ISlangUnknown* getInterface(const Guid& guid); - - void generateOutput(ComponentType* program); - void generateOutput(TargetProgram* targetProgram); - - void init(); - - Session* m_session = nullptr; - RefPtr<Linkage> m_linkage; - DiagnosticSink m_sink; - RefPtr<FrontEndCompileRequest> m_frontEndReq; - RefPtr<ComponentType> m_specializedGlobalComponentType; - RefPtr<ComponentType> m_specializedGlobalAndEntryPointsComponentType; - List<RefPtr<ComponentType>> m_specializedEntryPoints; - - // For output - - RefPtr<StdWriters> m_writers; -}; - -/* Returns SLANG_OK if pass through support is available */ -SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); -/* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic -sink. -@param compilerName The name of the compiler the error came for (or nullptr if not known) -@param res Result associated with the error. The error code will be reported. (Can take HRESULT - -and will expand to string if known) -@param diagnostic The diagnostic string associated with the compile failure -@param sink The diagnostic sink to report to */ -void reportExternalCompileError( - const char* compilerName, - SlangResult res, - const UnownedStringSlice& diagnostic, - DiagnosticSink* sink); // -// Information about BaseType that's useful for checking literals -struct BaseTypeInfo -{ - typedef uint8_t Flags; - struct Flag - { - enum Enum : Flags - { - Signed = 0x1, - FloatingPoint = 0x2, - Integer = 0x4, - }; - }; - - SLANG_FORCE_INLINE static const BaseTypeInfo& getInfo(BaseType baseType) - { - return s_info[Index(baseType)]; - } - - static UnownedStringSlice asText(BaseType baseType); - - uint8_t sizeInBytes; ///< Size of type in bytes - Flags flags; - uint8_t baseType; - - static bool check(); - -private: - static const BaseTypeInfo s_info[Index(BaseType::CountOf)]; -}; - -class CodeGenTransitionMap -{ -public: - struct Pair - { - typedef Pair ThisType; - SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const - { - return source == rhs.source && target == rhs.target; - } - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - SLANG_FORCE_INLINE HashCode getHashCode() const - { - return combineHash(HashCode(source), HashCode(target)); - } - - CodeGenTarget source; - CodeGenTarget target; - }; - - void removeTransition(CodeGenTarget source, CodeGenTarget target) - { - m_map.remove(Pair{source, target}); - } - void addTransition(CodeGenTarget source, CodeGenTarget target, PassThroughMode compiler) - { - SLANG_ASSERT(source != target); - m_map.set(Pair{source, target}, compiler); - } - bool hasTransition(CodeGenTarget source, CodeGenTarget target) const - { - return m_map.containsKey(Pair{source, target}); - } - PassThroughMode getTransition(CodeGenTarget source, CodeGenTarget target) const - { - const Pair pair{source, target}; - auto value = m_map.tryGetValue(pair); - return value ? *value : PassThroughMode::None; - } - -protected: - Dictionary<Pair, PassThroughMode> m_map; -}; - -class Session : public RefObject, public slang::IGlobalSession -{ -public: - SLANG_COM_INTERFACE( - 0xd6b767eb, - 0xd786, - 0x4343, - {0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a}) - - SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) - SLANG_OVERRIDE; - SLANG_REF_OBJECT_IUNKNOWN_ADD_REF - SLANG_REF_OBJECT_IUNKNOWN_RELEASE - - // slang::IGlobalSession - SLANG_NO_THROW SlangResult SLANG_MCALL - createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; - SLANG_NO_THROW SlangProfileID SLANG_MCALL findProfile(char const* name) override; - SLANG_NO_THROW void SLANG_MCALL - setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path) override; - SLANG_NO_THROW void SLANG_MCALL - setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) override; - SLANG_NO_THROW void SLANG_MCALL - getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) override; - SLANG_NO_THROW const char* SLANG_MCALL getBuildTagString() override; - SLANG_NO_THROW SlangResult SLANG_MCALL setDefaultDownstreamCompiler( - SlangSourceLanguage sourceLanguage, - SlangPassThrough defaultCompiler) override; - SLANG_NO_THROW SlangPassThrough SLANG_MCALL - getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage) override; - - SLANG_NO_THROW void SLANG_MCALL - setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) override; - SLANG_NO_THROW void SLANG_MCALL - getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL - createCompileRequest(slang::ICompileRequest** outCompileRequest) override; - - SLANG_NO_THROW void SLANG_MCALL - addBuiltins(char const* sourcePath, char const* sourceString) override; - SLANG_NO_THROW void SLANG_MCALL - setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) override; - SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL getSharedLibraryLoader() override; - SLANG_NO_THROW SlangResult SLANG_MCALL - checkCompileTargetSupport(SlangCompileTarget target) override; - SLANG_NO_THROW SlangResult SLANG_MCALL - checkPassThroughSupport(SlangPassThrough passThrough) override; - - void writeCoreModuleDoc(String config); - SLANG_NO_THROW SlangResult SLANG_MCALL - compileCoreModule(slang::CompileCoreModuleFlags flags) override; - SLANG_NO_THROW SlangResult SLANG_MCALL - loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) override; - SLANG_NO_THROW SlangResult SLANG_MCALL - saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL compileBuiltinModule( - slang::BuiltinModuleName moduleName, - slang::CompileCoreModuleFlags flags) override; - SLANG_NO_THROW SlangResult SLANG_MCALL loadBuiltinModule( - slang::BuiltinModuleName moduleName, - const void* coreModule, - size_t coreModuleSizeInBytes) override; - SLANG_NO_THROW SlangResult SLANG_MCALL saveBuiltinModule( - slang::BuiltinModuleName moduleName, - SlangArchiveType archiveType, - ISlangBlob** outBlob) override; - - SLANG_NO_THROW SlangCapabilityID SLANG_MCALL findCapability(char const* name) override; - - SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerForTransition( - SlangCompileTarget source, - SlangCompileTarget target, - SlangPassThrough compiler) override; - SLANG_NO_THROW SlangPassThrough SLANG_MCALL getDownstreamCompilerForTransition( - SlangCompileTarget source, - SlangCompileTarget target) override; - SLANG_NO_THROW void SLANG_MCALL - getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime) override - { - *outDownstreamTime = m_downstreamCompileTime; - *outTotalTime = m_totalCompileTime; - } - - SLANG_NO_THROW SlangResult SLANG_MCALL setSPIRVCoreGrammar(char const* jsonPath) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL parseCommandLineArguments( - int argc, - const char* const* argv, - slang::SessionDesc* outSessionDesc, - ISlangUnknown** outAllocation) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL - getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) override; - - /// Get the downstream compiler for a transition - IDownstreamCompiler* getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target); - - // This needs to be atomic not because of contention between threads as `Session` is - // *not* multithreaded, but can be used exclusively on one thread at a time. - // The need for atomic is purely for visibility. If the session is used on a different - // thread we need to be sure any changes to m_epochId are visible to this thread. - std::atomic<Index> m_epochId = 1; - - Scope* baseLanguageScope = nullptr; - Scope* coreLanguageScope = nullptr; - Scope* hlslLanguageScope = nullptr; - Scope* slangLanguageScope = nullptr; - Scope* glslLanguageScope = nullptr; - Name* glslModuleName = nullptr; - - ModuleDecl* baseModuleDecl = nullptr; - List<RefPtr<Module>> coreModules; - - SourceManager builtinSourceManager; - - SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } - - // Name pool stuff for unique-ing identifiers - - NamePool namePool; - - NamePool* getNamePool() { return &namePool; } - Name* getNameObj(String name) { return namePool.getName(name); } - Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } - // - - /// This AST Builder should only be used for creating AST nodes that are global across requests - /// not doing so could lead to memory being consumed but not used. - ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; } - void finalizeSharedASTBuilder(); - - RefPtr<ASTBuilder> globalAstBuilder; - - // Generated code for core module, etc. - String coreModulePath; - - ComPtr<ISlangBlob> coreLibraryCode; - // ComPtr<ISlangBlob> slangLibraryCode; - ComPtr<ISlangBlob> hlslLibraryCode; - ComPtr<ISlangBlob> glslLibraryCode; - ComPtr<ISlangBlob> autodiffLibraryCode; - - String getCoreModulePath(); - - ComPtr<ISlangBlob> getCoreLibraryCode(); - ComPtr<ISlangBlob> getHLSLLibraryCode(); - ComPtr<ISlangBlob> getAutodiffLibraryCode(); - ComPtr<ISlangBlob> getGLSLLibraryCode(); - - void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName); - - RefPtr<SharedASTBuilder> m_sharedASTBuilder; - - SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() - { - if (!spirvCoreGrammarInfo) - setSPIRVCoreGrammar(nullptr); - SLANG_ASSERT(spirvCoreGrammarInfo); - return *spirvCoreGrammarInfo; - } - RefPtr<SPIRVCoreGrammarInfo> spirvCoreGrammarInfo; - - // - - void _setSharedLibraryLoader(ISlangSharedLibraryLoader* loader); - - /// Will try to load the library by specified name (using the set loader), if not one already - /// available. - IDownstreamCompiler* getOrLoadDownstreamCompiler(PassThroughMode type, DiagnosticSink* sink); - /// Will unload the specified shared library if it's currently loaded - void resetDownstreamCompiler(PassThroughMode type); - - /// Get the prelude associated with the language - const String& getPreludeForLanguage(SourceLanguage language) - { - return m_languagePreludes[int(language)]; - } - - /// Get the built in linkage -> handy to get the core module from - Linkage* getBuiltinLinkage() const { return m_builtinLinkage; } - - Module* getBuiltinModule(slang::BuiltinModuleName builtinModuleName); - - Name* getCompletionRequestTokenName() const { return m_completionTokenName; } - - void init(); - - void addBuiltinSource( - Scope* scope, - String const& path, - ISlangBlob* sourceBlob, - Module*& outModule); - ~Session(); - - void addDownstreamCompileTime(double time) { m_downstreamCompileTime += time; } - void addTotalCompileTime(double time) { m_totalCompileTime += time; } - - ComPtr<ISlangSharedLibraryLoader> - m_sharedLibraryLoader; ///< The shared library loader (never null) - - int m_downstreamCompilerInitialized = 0; - - RefPtr<DownstreamCompilerSet> - m_downstreamCompilerSet; ///< Information about all available downstream compilers. - ComPtr<IDownstreamCompiler> m_downstreamCompilers[int( - PassThroughMode::CountOf)]; ///< A downstream compiler for a pass through - DownstreamCompilerLocatorFunc m_downstreamCompilerLocators[int(PassThroughMode::CountOf)]; - Name* m_completionTokenName = nullptr; ///< The name of a completion request token. - - /// For parsing command line options - CommandOptions m_commandOptions; - - int m_typeDictionarySize = 0; - - RefPtr<RefObject> m_typeCheckingCache; - TypeCheckingCache* getTypeCheckingCache(); - std::mutex m_typeCheckingCacheMutex; - -private: - struct BuiltinModuleInfo - { - const char* name; - Scope* languageScope; - }; - - BuiltinModuleInfo getBuiltinModuleInfo(slang::BuiltinModuleName name); - - void _initCodeGenTransitionMap(); - - SlangResult _readBuiltinModule( - ISlangFileSystem* fileSystem, - Scope* scope, - String moduleName, - Module*& outModule); - - SlangResult _loadRequest(EndToEndCompileRequest* request, const void* data, size_t size); - - /// Linkage used for all built-in (core module) code. - RefPtr<Linkage> m_builtinLinkage; - - String - m_downstreamCompilerPaths[int(PassThroughMode::CountOf)]; ///< Paths for each pass through - String m_languagePreludes[int(SourceLanguage::CountOf)]; ///< Prelude for each source language - PassThroughMode m_defaultDownstreamCompilers[int(SourceLanguage::CountOf)]; - - // Describes a conversion from one code gen target (source) to another (target) - CodeGenTransitionMap m_codeGenTransitionMap; - - double m_downstreamCompileTime = 0.0; - double m_totalCompileTime = 0.0; -}; - -const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name); - void checkTranslationUnit( TranslationUnitRequest* translationUnit, LoadedModuleDictionary& loadedModules); @@ -3897,109 +206,6 @@ SlangResult passthroughDownstreamDiagnostics( IDownstreamCompiler* compiler, IArtifact* artifact); -// -// The following functions are utilties to convert between -// matching "external" (public API) and "internal" (implementation) -// types. They are favored over explicit casts because they -// help avoid making incorrect conversions (e.g., when using -// `reinterpret_cast` or C-style casts), and because they -// abstract over the conversion required for each pair of types. -// - -SLANG_FORCE_INLINE slang::IGlobalSession* asExternal(Session* session) -{ - return static_cast<slang::IGlobalSession*>(session); -} - -SLANG_FORCE_INLINE ComPtr<Session> asInternal(slang::IGlobalSession* session) -{ - Slang::Session* internalSession = nullptr; - session->queryInterface(SLANG_IID_PPV_ARGS(&internalSession)); - return ComPtr<Session>(INIT_ATTACH, static_cast<Session*>(internalSession)); -} - -SLANG_FORCE_INLINE slang::ISession* asExternal(Linkage* linkage) -{ - return static_cast<slang::ISession*>(linkage); -} - -SLANG_FORCE_INLINE Module* asInternal(slang::IModule* module) -{ - return static_cast<Module*>(module); -} - -SLANG_FORCE_INLINE slang::IModule* asExternal(Module* module) -{ - return static_cast<slang::IModule*>(module); -} - -ComponentType* asInternal(slang::IComponentType* inComponentType); - -SLANG_FORCE_INLINE slang::IComponentType* asExternal(ComponentType* componentType) -{ - return static_cast<slang::IComponentType*>(componentType); -} - -SLANG_FORCE_INLINE slang::ProgramLayout* asExternal(ProgramLayout* programLayout) -{ - return (slang::ProgramLayout*)programLayout; -} - -SLANG_FORCE_INLINE Type* asInternal(slang::TypeReflection* type) -{ - return reinterpret_cast<Type*>(type); -} - -SLANG_FORCE_INLINE slang::TypeReflection* asExternal(Type* type) -{ - return reinterpret_cast<slang::TypeReflection*>(type); -} - -SLANG_FORCE_INLINE DeclRef<Decl> asInternal(slang::GenericReflection* generic) -{ - return DeclRef<Decl>(reinterpret_cast<DeclRefBase*>(generic)); -} - -SLANG_FORCE_INLINE slang::GenericReflection* asExternal(DeclRef<Decl> generic) -{ - return reinterpret_cast<slang::GenericReflection*>(generic.declRefBase); -} - -SLANG_FORCE_INLINE TypeLayout* asInternal(slang::TypeLayoutReflection* type) -{ - return reinterpret_cast<TypeLayout*>(type); -} - -SLANG_FORCE_INLINE slang::TypeLayoutReflection* asExternal(TypeLayout* type) -{ - return reinterpret_cast<slang::TypeLayoutReflection*>(type); -} - -SLANG_FORCE_INLINE SlangCompileRequest* asExternal(EndToEndCompileRequest* request) -{ - return static_cast<SlangCompileRequest*>(request); -} - -SLANG_FORCE_INLINE EndToEndCompileRequest* asInternal(SlangCompileRequest* request) -{ - // Converts to the internal type -- does a runtime type check through queryInterfae - SLANG_ASSERT(request); - EndToEndCompileRequest* endToEndRequest = nullptr; - // NOTE! We aren't using to access an interface, so *doesn't* return with a refcount - request->queryInterface(SLANG_IID_PPV_ARGS(&endToEndRequest)); - SLANG_ASSERT(endToEndRequest); - return endToEndRequest; -} - -SLANG_FORCE_INLINE SlangCompileTarget asExternal(CodeGenTarget target) -{ - return (SlangCompileTarget)target; -} - -SLANG_FORCE_INLINE SlangSourceLanguage asExternal(SourceLanguage sourceLanguage) -{ - return (SlangSourceLanguage)sourceLanguage; -} // helpers for error/warning reporting enum class DiagnosticCategory diff --git a/source/slang/slang-container-pool.cpp b/source/slang/slang-container-pool.cpp deleted file mode 100644 index e69de29bb..000000000 --- a/source/slang/slang-container-pool.cpp +++ /dev/null diff --git a/source/slang/slang-emit-dependency-file.cpp b/source/slang/slang-emit-dependency-file.cpp new file mode 100644 index 000000000..a482b08de --- /dev/null +++ b/source/slang/slang-emit-dependency-file.cpp @@ -0,0 +1,125 @@ +// slang-emit-dependency-file.cpp +#include "slang-emit-dependency-file.h" + +#include "slang-compiler.h" + +namespace Slang +{ + +static void _writeString(Stream& stream, const char* string) +{ + stream.write(string, strlen(string)); +} + +static void _escapeDependencyString(const char* string, StringBuilder& outBuilder) +{ + // make has unusual escaping rules, but we only care about characters that are acceptable in a + // path + for (const char* p = string; *p; ++p) + { + char c = *p; + switch (c) + { + case ' ': + case ':': + case '#': + case '[': + case ']': + case '\\': + outBuilder.appendChar('\\'); + break; + + case '$': + outBuilder.appendChar('$'); + break; + } + + outBuilder.appendChar(c); + } +} + +// Writes a line to the file stream, formatted like this: +// <output-file>: <dependency-file> <dependency-file...> +static void _writeDependencyStatement( + Stream& stream, + EndToEndCompileRequest* compileRequest, + const String& outputPath) +{ + if (outputPath.getLength() == 0) + return; + + StringBuilder builder; + _escapeDependencyString(outputPath.begin(), builder); + _writeString(stream, builder.begin()); + _writeString(stream, ": "); + + int dependencyCount = compileRequest->getDependencyFileCount(); + for (int dependencyIndex = 0; dependencyIndex < dependencyCount; ++dependencyIndex) + { + builder.clear(); + _escapeDependencyString(compileRequest->getDependencyFilePath(dependencyIndex), builder); + _writeString(stream, builder.begin()); + _writeString(stream, (dependencyIndex + 1 < dependencyCount) ? " " : "\n"); + } +} + +// Writes a file with dependency info, with one line in the output file per compile product. +SlangResult writeDependencyFile(EndToEndCompileRequest* compileRequest) +{ + if (compileRequest->m_dependencyOutputPath.getLength() == 0) + return SLANG_OK; + + FileStream stream; + SLANG_RETURN_ON_FAIL(stream.init( + compileRequest->m_dependencyOutputPath, + FileMode::Create, + FileAccess::Write, + FileShare::ReadWrite)); + + auto linkage = compileRequest->getLinkage(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + + // Iterate over all the targets and their outputs + for (const auto& targetReq : linkage->targets) + { + if (compileRequest->getTargetOptionSet(targetReq).getBoolOption( + CompilerOptionName::GenerateWholeProgram)) + { + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) + { + _writeDependencyStatement( + stream, + compileRequest, + targetInfo->wholeTargetOutputPath); + } + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index entryPointIndex = 0; entryPointIndex < entryPointCount; ++entryPointIndex) + { + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) + { + String outputPath; + if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) + { + _writeDependencyStatement(stream, compileRequest, outputPath); + } + } + } + } + } + + // When the output is a binary module, linkage->targets can be empty. So + // we need to do their dependencies separately. + if (compileRequest->m_containerFormat == ContainerFormat::SlangModule) + { + _writeDependencyStatement(stream, compileRequest, compileRequest->m_containerOutputPath); + } + + return SLANG_OK; +} + +} // namespace Slang diff --git a/source/slang/slang-emit-dependency-file.h b/source/slang/slang-emit-dependency-file.h new file mode 100644 index 000000000..27e584508 --- /dev/null +++ b/source/slang/slang-emit-dependency-file.h @@ -0,0 +1,20 @@ +// slang-emit-dependency-file.h +#pragma once + +// +// This file defines the interface for emitting a +// dependency file (in the same format used by `make`, +// `gcc`, and various other tools) based on a compile +// request using the `slangc` tool. +// + +#include <slang.h> + +namespace Slang +{ +class EndToEndCompileRequest; + +SlangResult writeDependencyFile(EndToEndCompileRequest* compileRequest); + + +} // namespace Slang diff --git a/source/slang/slang-end-to-end-request.cpp b/source/slang/slang-end-to-end-request.cpp new file mode 100644 index 000000000..d0f0a6f53 --- /dev/null +++ b/source/slang/slang-end-to-end-request.cpp @@ -0,0 +1,1999 @@ +// slang-end-to-end-request.cpp +#include "slang-end-to-end-request.h" + +#include "compiler-core/slang-pretty-writer.h" +#include "core/slang-memory-file-system.h" +#include "core/slang-performance-profiler.h" +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-emit-dependency-file.h" +#include "slang-module-library.h" +#include "slang-options.h" +#include "slang-reflection-json.h" +#include "slang-repro.h" +#include "slang-serialize-container.h" + +// TODO: The "artifact" system is a scourge. +#include "compiler-core/slang-artifact-associated-impl.h" +#include "compiler-core/slang-artifact-container-util.h" +#include "compiler-core/slang-artifact-desc-util.h" +#include "compiler-core/slang-artifact-impl.h" +#include "compiler-core/slang-artifact-util.h" +#include "slang-artifact-output-util.h" + +namespace Slang +{ + +EndToEndCompileRequest::EndToEndCompileRequest(Session* session) + : m_session(session), m_sink(nullptr, Lexer::sourceLocationLexer) +{ + RefPtr<ASTBuilder> astBuilder( + new ASTBuilder(session->m_sharedASTBuilder, "EndToEnd::Linkage::astBuilder")); + m_linkage = new Linkage(session, astBuilder, session->getBuiltinLinkage()); + init(); +} + +EndToEndCompileRequest::EndToEndCompileRequest(Linkage* linkage) + : m_session(linkage->getSessionImpl()) + , m_linkage(linkage) + , m_sink(nullptr, Lexer::sourceLocationLexer) +{ + init(); +} + +void EndToEndCompileRequest::init() +{ + m_sink.setSourceManager(m_linkage->getSourceManager()); + + m_writers = new StdWriters; + + // Set all the default writers + for (int i = 0; i < int(WriterChannel::CountOf); ++i) + { + setWriter(WriterChannel(i), nullptr); + } + + m_frontEndReq = new FrontEndCompileRequest(getLinkage(), m_writers, getSink()); +} + +EndToEndCompileRequest::~EndToEndCompileRequest() +{ + // Flush any writers associated with the request + m_writers->flushWriters(); + + m_linkage.setNull(); + m_frontEndReq.setNull(); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +EndToEndCompileRequest::queryInterface(SlangUUID const& uuid, void** outObject) +{ + if (uuid == EndToEndCompileRequest::getTypeGuid()) + { + // Special case to cast directly into internal type + // NOTE! No addref(!) + *outObject = this; + return SLANG_OK; + } + + if (uuid == ISlangUnknown::getTypeGuid() && uuid == ICompileRequest::getTypeGuid()) + { + addReference(); + *outObject = static_cast<slang::ICompileRequest*>(this); + return SLANG_OK; + } + + return SLANG_E_NO_INTERFACE; +} + +// Try to infer a single common source language for a request +static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) +{ + SourceLanguage language = SourceLanguage::Unknown; + for (auto& translationUnit : request->translationUnits) + { + // Allow any other language to overide Slang as a choice + if (language == SourceLanguage::Unknown || language == SourceLanguage::Slang) + { + language = translationUnit->sourceLanguage; + } + else if (language == translationUnit->sourceLanguage) + { + // same language as we currently have, so keep going + } + else + { + // we found a mismatch, so inference fails + return SourceLanguage::Unknown; + } + } + return language; +} + +SlangResult EndToEndCompileRequest::executeActionsInner() +{ + SLANG_PROFILE_SECTION(endToEndActions); + // If no code-generation target was specified, then try to infer one from the source language, + // just to make sure we can do something reasonable when invoked from the command line. + // + // TODO: This logic should be moved into `options.cpp` or somewhere else + // specific to the command-line tool. + // + if (getLinkage()->targets.getCount() == 0) + { + auto language = inferSourceLanguage(getFrontEndReq()); + switch (language) + { + case SourceLanguage::HLSL: + getLinkage()->addTarget(CodeGenTarget::DXBytecode); + break; + + case SourceLanguage::GLSL: + getLinkage()->addTarget(CodeGenTarget::SPIRV); + break; + + default: + break; + } + } + + // Update compiler settings in target requests. + for (auto target : getLinkage()->targets) + target->getOptionSet().inheritFrom(getOptionSet()); + m_frontEndReq->optionSet = getOptionSet(); + + // We only do parsing and semantic checking if we *aren't* doing + // a pass-through compilation. + // + if (m_passThrough == PassThroughMode::None) + { + SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner()); + } + + if (getOptionSet().getBoolOption(CompilerOptionName::PreprocessorOutput)) + { + return SLANG_OK; + } + + // If command line specifies to skip codegen, we exit here. + // Note: this is a debugging option. + // + if (getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) + { + // We will use the program (and matching layout information) + // that was computed in the front-end for all subsequent + // reflection queries, etc. + // + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = + getUnspecializedGlobalAndEntryPointsComponentType(); + m_specializedEntryPoints = getFrontEndReq()->getUnspecializedEntryPoints(); + + SLANG_RETURN_ON_FAIL(maybeCreateContainer()); + + SLANG_RETURN_ON_FAIL(maybeWriteContainer(m_containerOutputPath)); + + return SLANG_OK; + } + + // If requested, attempt to compile the translation unit all the way down to the target + // language(s) and stash the result blobs in IR. + for (auto target : getLinkage()->targets) + { + SlangCompileTarget targetEnum = SlangCompileTarget(target->getTarget()); + if (target->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) + { + auto frontEndReq = getFrontEndReq(); + + for (auto translationUnit : frontEndReq->translationUnits) + { + SLANG_RETURN_ON_FAIL( + translationUnit->getModule()->precompileForTarget(targetEnum, nullptr)); + + if (frontEndReq->optionSet.shouldDumpIR()) + { + DiagnosticSinkWriter writer(frontEndReq->getSink()); + + dumpIR( + translationUnit->getModule()->getIRModule(), + frontEndReq->m_irDumpOptions, + "PRECOMPILE_FOR_TARGET_COMPLETE_ALL", + frontEndReq->getSourceManager(), + &writer); + + dumpIR( + translationUnit->getModule()->getIRModule()->getModuleInst(), + frontEndReq->m_irDumpOptions, + frontEndReq->getSourceManager(), + &writer); + } + } + } + } + + // If codegen is enabled, we need to move along to + // apply any generic specialization that the user asked for. + // + if (m_passThrough == PassThroughMode::None) + { + m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + m_specializedGlobalAndEntryPointsComponentType = + createSpecializedGlobalAndEntryPointsComponentType(this, m_specializedEntryPoints); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + // For each code generation target, we will generate specialized + // parameter binding information (taking global generic + // arguments into account at this time). + // + for (auto targetReq : getLinkage()->targets) + { + auto targetProgram = + m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); + } + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + } + else + { + // We need to create dummy `EntryPoint` objects + // to make sure that the logic in `generateOutput` + // sees something worth processing. + // + List<RefPtr<ComponentType>> dummyEntryPoints; + for (auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) + { + RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough( + getLinkage(), + entryPointReq->getName(), + entryPointReq->getProfile()); + + dummyEntryPoints.add(dummyEntryPoint); + } + + RefPtr<ComponentType> composedProgram = + CompositeComponentType::create(getLinkage(), dummyEntryPoints); + + m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); + m_specializedGlobalAndEntryPointsComponentType = composedProgram; + m_specializedEntryPoints = getFrontEndReq()->getUnspecializedEntryPoints(); + } + + // Generate output code, in whatever format was requested + generateOutput(); + if (getSink()->getErrorCount() != 0) + return SLANG_FAIL; + + return SLANG_OK; +} + +// Act as expected of the API-based compiler +SlangResult EndToEndCompileRequest::executeActions() +{ + SlangResult res = executeActionsInner(); + + m_diagnosticOutput = getSink()->outputBuffer.produceString(); + return res; +} + + +static ISlangWriter* _getDefaultWriter(WriterChannel chan) +{ + static FileWriter stdOut(stdout, WriterFlag::IsStatic | WriterFlag::IsUnowned); + static FileWriter stdError(stderr, WriterFlag::IsStatic | WriterFlag::IsUnowned); + static NullWriter nullWriter(WriterFlag::IsStatic | WriterFlag::IsConsole); + + switch (chan) + { + case WriterChannel::StdError: + return &stdError; + case WriterChannel::StdOutput: + return &stdOut; + case WriterChannel::Diagnostic: + return &nullWriter; + default: + { + SLANG_ASSERT(!"Unknown type"); + return &stdError; + } + } +} + +void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) +{ + // If the user passed in null, we will use the default writer on that channel + m_writers->setWriter(SlangWriterChannel(chan), writer ? writer : _getDefaultWriter(chan)); + + // For diagnostic output, if the user passes in nullptr, we set on m_sink.writer as that enables + // buffering on DiagnosticSink + if (chan == WriterChannel::Diagnostic) + { + m_sink.writer = writer; + } +} + +void EndToEndCompileRequest::writeArtifactToStandardOutput( + IArtifact* artifact, + DiagnosticSink* sink) +{ + // If it's host callable it's not available to write to output + if (isDerivedFrom(artifact->getDesc().kind, ArtifactKind::HostCallable)) + { + return; + } + + auto session = getSession(); + ArtifactOutputUtil::maybeConvertAndWrite( + session, + artifact, + sink, + toSlice("stdout"), + getWriter(WriterChannel::StdOutput)); +} + +String EndToEndCompileRequest::_getWholeProgramPath(TargetRequest* targetReq) +{ + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if (m_targetInfos.tryGetValue(targetReq, targetInfo)) + { + return targetInfo->wholeTargetOutputPath; + } + return String(); +} + +String EndToEndCompileRequest::_getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex) +{ + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if (m_targetInfos.tryGetValue(targetReq, targetInfo)) + { + String outputPath; + if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) + { + return outputPath; + } + } + + return String(); +} + +SlangResult EndToEndCompileRequest::_writeArtifact(const String& path, IArtifact* artifact) +{ + if (path.getLength() > 0) + { + SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::writeToFile(artifact, getSink(), path)); + } + else if (m_containerFormat == ContainerFormat::None) + { + // If we aren't writing to a container and we didn't write to a file, we can output to + // standard output + writeArtifactToStandardOutput(artifact, getSink()); + } + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::_maybeWriteArtifact(const String& path, IArtifact* artifact) +{ + // We don't have to do anything if there is no artifact + if (!artifact) + { + return SLANG_OK; + } + + // If embedding is enabled... + if (m_sourceEmbedStyle != SourceEmbedUtil::Style::None) + { + SourceEmbedUtil::Options options; + + options.style = m_sourceEmbedStyle; + options.variableName = m_sourceEmbedName; + options.language = (SlangSourceLanguage)m_sourceEmbedLanguage; + + ComPtr<IArtifact> embeddedArtifact; + SLANG_RETURN_ON_FAIL(SourceEmbedUtil::createEmbedded(artifact, options, embeddedArtifact)); + + if (!embeddedArtifact) + { + return SLANG_FAIL; + } + SLANG_RETURN_ON_FAIL( + _writeArtifact(SourceEmbedUtil::getPath(path, options), embeddedArtifact)); + return SLANG_OK; + } + else + { + SLANG_RETURN_ON_FAIL(_writeArtifact(path, artifact)); + } + + return SLANG_OK; +} + +// These helper functions are used by the -separate-debug-info command line +// arg to extract the associated artifact containing the debug SPIRV data +// and save it to a file with a .dbg.spv extension. +static String _getDebugSpvPath(const String& basePath) +{ + // Find the last occurrence of ".spv" at the end of the string. + static const char ext[] = ".spv"; + static const char dbgExt[] = ".dbg.spv"; + Index extLen = 4; + if (basePath.getLength() >= extLen && basePath.endsWith(ext)) + { + // Replace the ".spv" extension with ".dbg.spv" + String prefix = String(basePath.subString(0, basePath.getLength() - extLen)); + return prefix + dbgExt; + } + // If it doesn't end with .spv, just append .dbg.spv + return basePath + dbgExt; +} + +SlangResult EndToEndCompileRequest::_maybeWriteDebugArtifact( + TargetProgram* targetProgram, + const String& path, + IArtifact* artifact) +{ + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmitSeparateDebug)) + { + const auto dbgArtifact = getSeparateDbgArtifact(artifact); + // Check if a debug artifact was actually created (only for SPIR-V targets) + if (dbgArtifact) + { + // The artifact's name may have been set to the debug build id hash, use + // it as the filename if it exists. + String dbgPath = dbgArtifact->getName(); + if (dbgPath.getLength() == 0) + dbgPath = _getDebugSpvPath(path); + else + dbgPath.append(".dbg.spv"); + return _maybeWriteArtifact(dbgPath, dbgArtifact); + } + // If no debug artifact exists (e.g., for non-SPIR-V targets), just silently succeed + // The warning about unsupported targets is already issued during option parsing + } + + return SLANG_OK; +} + +void EndToEndCompileRequest::generateOutput(TargetProgram* targetProgram) +{ + auto program = targetProgram->getProgram(); + + // Generate target code any entry points that + // have been requested for compilation. + auto entryPointCount = program->getEntryPointCount(); + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) + { + targetProgram->_createWholeProgramResult(getSink(), this); + } + else + { + for (Index ii = 0; ii < entryPointCount; ++ii) + { + targetProgram->_createEntryPointResult(ii, getSink(), this); + } + } +} + +bool _shouldWriteSourceLocs(Linkage* linkage) +{ + // If debug information or source manager are not avaiable we can't/shouldn't write out locs + if (linkage->m_optionSet.getEnumOption<DebugInfoLevel>(CompilerOptionName::DebugInformation) == + DebugInfoLevel::None || + linkage->getSourceManager() == nullptr) + { + return false; + } + + // Otherwise we do want to write out the locs + return true; +} + +SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) +{ + auto linkage = getLinkage(); + + // Set up options + SerialContainerUtil::WriteOptions options; + + // If debug information is enabled, enable writing out source locs + if (_shouldWriteSourceLocs(linkage)) + { + options.sourceManagerToUseWhenSerializingSourceLocs = linkage->getSourceManager(); + } + + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, options, stream)); + + return SLANG_OK; +} + +static IBoxValue<SourceMap>* _getObfuscatedSourceMap(TranslationUnitRequest* translationUnit) +{ + if (auto module = translationUnit->getModule()) + { + if (auto irModule = module->getIRModule()) + { + return irModule->getObfuscatedSourceMap(); + } + } + return nullptr; +} + +SlangResult EndToEndCompileRequest::maybeCreateContainer() +{ + m_containerArtifact.setNull(); + + List<ComPtr<IArtifact>> artifacts; + + auto linkage = getLinkage(); + + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + + for (auto targetReq : linkage->targets) + { + auto targetProgram = program->getTargetProgram(targetReq); + + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) + { + if (auto artifact = targetProgram->getExistingWholeProgramResult()) + { + if (!targetProgram->getOptionSet().getBoolOption( + CompilerOptionName::EmbedDownstreamIR)) + { + artifacts.add(ComPtr<IArtifact>(artifact)); + } + } + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) + { + if (auto artifact = targetProgram->getExistingEntryPointResult(ee)) + { + artifacts.add(ComPtr<IArtifact>(artifact)); + } + } + } + } + + // If IR emitting is enabled, add IR to the artifacts + if (m_emitIr && (m_containerFormat == ContainerFormat::SlangModule)) + { + OwnedMemoryStream stream(FileAccess::Write); + SlangResult res = writeContainerToStream(&stream); + if (SLANG_FAILED(res)) + { + getSink()->diagnose(SourceLoc(), Diagnostics::unableToCreateModuleContainer); + return res; + } + + // Need to turn into a blob + List<uint8_t> blobData; + stream.swapContents(blobData); + + auto containerBlob = ListBlob::moveCreate(blobData); + + auto irArtifact = Artifact::create(ArtifactDesc::make( + Artifact::Kind::CompileBinary, + ArtifactPayload::SlangIR, + ArtifactStyle::Unknown)); + irArtifact->addRepresentationUnknown(containerBlob); + + // Add the IR artifact + artifacts.add(irArtifact); + } + + // If there is only one artifact we can use that as the container + if (artifacts.getCount() == 1) + { + m_containerArtifact = artifacts[0]; + } + else + { + m_containerArtifact = ArtifactUtil::createArtifact( + ArtifactDesc::make(ArtifactKind::Container, ArtifactPayload::CompileResults)); + + for (IArtifact* childArtifact : artifacts) + { + m_containerArtifact->addChild(childArtifact); + } + } + + // Get all of the source obfuscated source maps and add those + if (m_containerArtifact) + { + auto frontEndReq = getFrontEndReq(); + + for (auto translationUnit : frontEndReq->translationUnits) + { + // Hmmm do I have to therefore add a map for all translation units(!) + // I guess this is okay in so far as an association can always be looked up by name + if (auto sourceMap = _getObfuscatedSourceMap(translationUnit)) + { + auto artifactDesc = ArtifactDesc::make( + ArtifactKind::Json, + ArtifactPayload::SourceMap, + ArtifactStyle::Obfuscated); + + // Create the source map artifact + auto sourceMapArtifact = + Artifact::create(artifactDesc, sourceMap->get().m_file.getUnownedSlice()); + + // Add the repesentation + sourceMapArtifact->addRepresentation(sourceMap); + + // Associate with the container + m_containerArtifact->addAssociated(sourceMapArtifact); + } + } + } + + return SLANG_OK; +} + +CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(TargetRequest* req) +{ + return req->getOptionSet(); +} + +CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(Index targetIndex) +{ + return m_linkage->targets[targetIndex]->getOptionSet(); +} + +SlangResult EndToEndCompileRequest::maybeWriteContainer(const String& fileName) +{ + // If there is no container, or filename, don't write anything + if (fileName.getLength() == 0 || !m_containerArtifact) + { + return SLANG_OK; + } + + // Filter the containerArtifact into things that can be written + ComPtr<IArtifact> writeArtifact; + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)); + + // Only write if there is something to write + if (writeArtifact) + { + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, fileName)); + } + + return SLANG_OK; +} + +void EndToEndCompileRequest::generateOutput(ComponentType* program) +{ + // When dynamic dispatch is disabled, the program must + // be fully specialized by now. So we check if we still + // have unspecialized generic/existential parameters, + // and report them as an error. + // + auto specializationParamCount = program->getSpecializationParamCount(); + if (getOptionSet().getBoolOption(CompilerOptionName::DisableDynamicDispatch) && + specializationParamCount != 0) + { + auto sink = getSink(); + + for (Index ii = 0; ii < specializationParamCount; ++ii) + { + auto specializationParam = program->getSpecializationParam(ii); + if (auto decl = as<Decl>(specializationParam.object)) + { + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterOfNameNotSpecialized, + decl); + } + else if (auto type = as<Type>(specializationParam.object)) + { + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterOfNameNotSpecialized, + type); + } + else + { + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterNotSpecialized); + } + } + + return; + } + + + // Go through the code-generation targets that the user + // has specified, and generate code for each of them. + // + auto linkage = getLinkage(); + for (auto targetReq : linkage->targets) + { + if (targetReq->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) + continue; + + auto targetProgram = program->getTargetProgram(targetReq); + generateOutput(targetProgram); + } +} + +void EndToEndCompileRequest::generateOutput() +{ + SLANG_PROFILE; + generateOutput(getSpecializedGlobalAndEntryPointsComponentType()); + + // If we are in command-line mode, we might be expected to actually + // write output to one or more files here. + + if (m_isCommandLineCompile && m_containerFormat == ContainerFormat::None) + { + auto linkage = getLinkage(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + + for (auto targetReq : linkage->targets) + { + auto targetProgram = program->getTargetProgram(targetReq); + + if (targetProgram->getOptionSet().getBoolOption( + CompilerOptionName::GenerateWholeProgram)) + { + if (const auto artifact = targetProgram->getExistingWholeProgramResult()) + { + const auto path = _getWholeProgramPath(targetReq); + + _maybeWriteArtifact(path, artifact); + + // If we are compiling separate debug info, check for the additional + // SPIRV artifact and write that if needed. + _maybeWriteDebugArtifact(targetProgram, path, artifact); + } + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) + { + if (const auto artifact = targetProgram->getExistingEntryPointResult(ee)) + { + const auto path = _getEntryPointPath(targetReq, ee); + + _maybeWriteArtifact(path, artifact); + + // If we are compiling separate debug info, check for the additional + // SPIRV artifact and write that if needed. + _maybeWriteDebugArtifact(targetProgram, path, artifact); + } + } + } + } + } + + // Maybe create the container + maybeCreateContainer(); + + // If it's a command line compile we may need to write the container to a file + if (m_isCommandLineCompile) + { + // TODO(JS): + // We could write the container into a source embedded format potentially + + maybeWriteContainer(m_containerOutputPath); + + writeDependencyFile(this); + } +} + +void EndToEndCompileRequest::setFileSystem(ISlangFileSystem* fileSystem) +{ + getLinkage()->setFileSystem(fileSystem); +} + +void EndToEndCompileRequest::setCompileFlags(SlangCompileFlags flags) +{ + if (flags & SLANG_COMPILE_FLAG_NO_MANGLING) + getOptionSet().set(CompilerOptionName::NoMangle, true); + if (flags & SLANG_COMPILE_FLAG_NO_CODEGEN) + getOptionSet().set(CompilerOptionName::SkipCodeGen, true); + if (flags & SLANG_COMPILE_FLAG_OBFUSCATE) + getOptionSet().set(CompilerOptionName::Obfuscate, true); +} + +SlangCompileFlags EndToEndCompileRequest::getCompileFlags() +{ + SlangCompileFlags result = 0; + if (getOptionSet().getBoolOption(CompilerOptionName::NoMangle)) + result |= SLANG_COMPILE_FLAG_NO_MANGLING; + if (getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) + result |= SLANG_COMPILE_FLAG_NO_CODEGEN; + if (getOptionSet().getBoolOption(CompilerOptionName::Obfuscate)) + result |= SLANG_COMPILE_FLAG_OBFUSCATE; + return result; +} + +void EndToEndCompileRequest::setDumpIntermediates(int enable) +{ + getOptionSet().set(CompilerOptionName::DumpIntermediates, enable); +} + +void EndToEndCompileRequest::setTrackLiveness(bool v) +{ + getOptionSet().set(CompilerOptionName::TrackLiveness, v); +} + +void EndToEndCompileRequest::setDumpIntermediatePrefix(const char* prefix) +{ + getOptionSet().set(CompilerOptionName::DumpIntermediatePrefix, String(prefix)); +} + +void EndToEndCompileRequest::setLineDirectiveMode(SlangLineDirectiveMode mode) +{ + getOptionSet().set(CompilerOptionName::LineDirectiveMode, mode); +} + +void EndToEndCompileRequest::setCommandLineCompilerMode() +{ + m_isCommandLineCompile = true; + + // legacy slangc tool defaults to column major layout. + if (!getOptionSet().hasOption(CompilerOptionName::MatrixLayoutRow)) + getOptionSet().setMatrixLayoutMode(kMatrixLayoutMode_ColumnMajor); +} + +void EndToEndCompileRequest::_completeTargetRequest(UInt targetIndex) +{ + auto linkage = getLinkage(); + + TargetRequest* targetRequest = linkage->targets[Index(targetIndex)]; + + targetRequest->getOptionSet().inheritFrom(getLinkage()->m_optionSet); + targetRequest->getOptionSet().inheritFrom(m_optionSetForDefaultTarget); +} + +void EndToEndCompileRequest::setCodeGenTarget(SlangCompileTarget target) +{ + auto linkage = getLinkage(); + linkage->targets.clear(); + const auto targetIndex = linkage->addTarget(CodeGenTarget(target)); + SLANG_ASSERT(targetIndex == 0); + _completeTargetRequest(0); +} + +int EndToEndCompileRequest::addCodeGenTarget(SlangCompileTarget target) +{ + const auto targetIndex = getLinkage()->addTarget(CodeGenTarget(target)); + _completeTargetRequest(targetIndex); + return int(targetIndex); +} + +void EndToEndCompileRequest::setTargetProfile(int targetIndex, SlangProfileID profile) +{ + getTargetOptionSet(targetIndex).setProfile(Profile(profile)); +} + +void EndToEndCompileRequest::setTargetFlags(int targetIndex, SlangTargetFlags flags) +{ + getTargetOptionSet(targetIndex).setTargetFlags(flags); +} + +void EndToEndCompileRequest::setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) +{ + getTargetOptionSet(targetIndex).set(CompilerOptionName::GLSLForceScalarLayout, value); +} + +void EndToEndCompileRequest::setTargetForceDXLayout(int targetIndex, bool value) +{ + getTargetOptionSet(targetIndex).set(CompilerOptionName::ForceDXLayout, value); +} + +void EndToEndCompileRequest::setTargetFloatingPointMode( + int targetIndex, + SlangFloatingPointMode mode) +{ + getTargetOptionSet(targetIndex) + .set(CompilerOptionName::FloatingPointMode, FloatingPointMode(mode)); +} + +void EndToEndCompileRequest::setMatrixLayoutMode(SlangMatrixLayoutMode mode) +{ + getOptionSet().setMatrixLayoutMode((MatrixLayoutMode)mode); +} + +void EndToEndCompileRequest::setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) +{ + getTargetOptionSet(targetIndex).setMatrixLayoutMode(MatrixLayoutMode(mode)); +} + +void EndToEndCompileRequest::setTargetGenerateWholeProgram(int targetIndex, bool value) +{ + getTargetOptionSet(targetIndex).set(CompilerOptionName::GenerateWholeProgram, value); +} + +void EndToEndCompileRequest::setTargetEmbedDownstreamIR(int targetIndex, bool value) +{ + getTargetOptionSet(targetIndex).set(CompilerOptionName::EmbedDownstreamIR, value); +} + +void EndToEndCompileRequest::setTargetLineDirectiveMode( + SlangInt targetIndex, + SlangLineDirectiveMode mode) +{ + getTargetOptionSet(targetIndex) + .set(CompilerOptionName::LineDirectiveMode, LineDirectiveMode(mode)); +} + +void EndToEndCompileRequest::overrideDiagnosticSeverity( + SlangInt messageID, + SlangSeverity overrideSeverity) +{ + getSink()->overrideDiagnosticSeverity(int(messageID), Severity(overrideSeverity)); +} + +SlangDiagnosticFlags EndToEndCompileRequest::getDiagnosticFlags() +{ + DiagnosticSink::Flags sinkFlags = getSink()->getFlags(); + + SlangDiagnosticFlags flags = 0; + + if (sinkFlags & DiagnosticSink::Flag::VerbosePath) + flags |= SLANG_DIAGNOSTIC_FLAG_VERBOSE_PATHS; + + if (sinkFlags & DiagnosticSink::Flag::TreatWarningsAsErrors) + flags |= SLANG_DIAGNOSTIC_FLAG_TREAT_WARNINGS_AS_ERRORS; + + return flags; +} + +void EndToEndCompileRequest::setDiagnosticFlags(SlangDiagnosticFlags flags) +{ + DiagnosticSink::Flags sinkFlags = getSink()->getFlags(); + + if (flags & SLANG_DIAGNOSTIC_FLAG_VERBOSE_PATHS) + sinkFlags |= DiagnosticSink::Flag::VerbosePath; + else + sinkFlags &= ~DiagnosticSink::Flag::VerbosePath; + + if (flags & SLANG_DIAGNOSTIC_FLAG_TREAT_WARNINGS_AS_ERRORS) + sinkFlags |= DiagnosticSink::Flag::TreatWarningsAsErrors; + else + sinkFlags &= ~DiagnosticSink::Flag::TreatWarningsAsErrors; + + getSink()->setFlags(sinkFlags); +} + +SlangResult EndToEndCompileRequest::addTargetCapability( + SlangInt targetIndex, + SlangCapabilityID capability) +{ + auto& targets = getLinkage()->targets; + if (targetIndex < 0 || targetIndex >= targets.getCount()) + return SLANG_E_INVALID_ARG; + getTargetOptionSet(targetIndex).addCapabilityAtom(CapabilityName(capability)); + return SLANG_OK; +} + +void EndToEndCompileRequest::setDebugInfoLevel(SlangDebugInfoLevel level) +{ + getOptionSet().set(CompilerOptionName::DebugInformation, DebugInfoLevel(level)); +} + +void EndToEndCompileRequest::setDebugInfoFormat(SlangDebugInfoFormat format) +{ + getOptionSet().set(CompilerOptionName::DebugInformationFormat, DebugInfoFormat(format)); +} + +void EndToEndCompileRequest::setOptimizationLevel(SlangOptimizationLevel level) +{ + getOptionSet().set(CompilerOptionName::Optimization, OptimizationLevel(level)); +} + +void EndToEndCompileRequest::setOutputContainerFormat(SlangContainerFormat format) +{ + m_containerFormat = ContainerFormat(format); +} + +void EndToEndCompileRequest::setPassThrough(SlangPassThrough inPassThrough) +{ + m_passThrough = PassThroughMode(inPassThrough); +} + +void EndToEndCompileRequest::setReportDownstreamTime(bool value) +{ + getOptionSet().set(CompilerOptionName::ReportDownstreamTime, value); +} + +void EndToEndCompileRequest::setReportPerfBenchmark(bool value) +{ + getOptionSet().set(CompilerOptionName::ReportPerfBenchmark, value); +} + +void EndToEndCompileRequest::setSkipSPIRVValidation(bool value) +{ + getOptionSet().set(CompilerOptionName::SkipSPIRVValidation, value); +} + +void EndToEndCompileRequest::setTargetUseMinimumSlangOptimization(int targetIndex, bool value) +{ + getTargetOptionSet(targetIndex).set(CompilerOptionName::MinimumSlangOptimization, value); +} + +void EndToEndCompileRequest::setIgnoreCapabilityCheck(bool value) +{ + getOptionSet().set(CompilerOptionName::IgnoreCapabilities, value); +} + +void EndToEndCompileRequest::setDiagnosticCallback( + SlangDiagnosticCallback callback, + void const* userData) +{ + ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole)); + setWriter(WriterChannel::Diagnostic, writer); +} + +void EndToEndCompileRequest::setWriter(SlangWriterChannel chan, ISlangWriter* writer) +{ + setWriter(WriterChannel(chan), writer); +} + +ISlangWriter* EndToEndCompileRequest::getWriter(SlangWriterChannel chan) +{ + return getWriter(WriterChannel(chan)); +} + +void EndToEndCompileRequest::addSearchPath(const char* path) +{ + getOptionSet().addSearchPath(path); +} + +void EndToEndCompileRequest::addPreprocessorDefine(const char* key, const char* value) +{ + getOptionSet().addPreprocessorDefine(key, value); +} + +void EndToEndCompileRequest::setEnableEffectAnnotations(bool value) +{ + getOptionSet().set(CompilerOptionName::EnableEffectAnnotations, value); +} + +char const* EndToEndCompileRequest::getDiagnosticOutput() +{ + return m_diagnosticOutput.begin(); +} + +SlangResult EndToEndCompileRequest::getDiagnosticOutputBlob(ISlangBlob** outBlob) +{ + if (!outBlob) + return SLANG_E_INVALID_ARG; + + if (!m_diagnosticOutputBlob) + { + m_diagnosticOutputBlob = StringUtil::createStringBlob(m_diagnosticOutput); + } + + ComPtr<ISlangBlob> resultBlob = m_diagnosticOutputBlob; + *outBlob = resultBlob.detach(); + return SLANG_OK; +} + +int EndToEndCompileRequest::addTranslationUnit(SlangSourceLanguage language, char const* inName) +{ + auto frontEndReq = getFrontEndReq(); + NamePool* namePool = frontEndReq->getNamePool(); + + // Work out a module name. Can be nullptr if so will generate a name + Name* moduleName = inName ? namePool->getName(inName) : frontEndReq->m_defaultModuleName; + + // If moduleName is nullptr a name will be generated + return frontEndReq->addTranslationUnit(Slang::SourceLanguage(language), moduleName); +} + +void EndToEndCompileRequest::setDefaultModuleName(const char* defaultModuleName) +{ + auto frontEndReq = getFrontEndReq(); + NamePool* namePool = frontEndReq->getNamePool(); + frontEndReq->m_defaultModuleName = namePool->getName(defaultModuleName); +} + +SlangResult _addLibraryReference( + EndToEndCompileRequest* req, + ModuleLibrary* moduleLibrary, + bool includeEntryPoint) +{ + FrontEndCompileRequest* frontEndRequest = req->getFrontEndReq(); + + if (includeEntryPoint) + { + frontEndRequest->m_extraEntryPoints.addRange( + moduleLibrary->m_entryPoints.getBuffer(), + moduleLibrary->m_entryPoints.getCount()); + } + + for (auto m : moduleLibrary->m_modules) + { + RefPtr<TranslationUnitRequest> tu = new TranslationUnitRequest(frontEndRequest, m); + frontEndRequest->translationUnits.add(tu); + // For modules loaded for EndToEndCompileRequest, + // we don't need the automatically discovered entrypoints. + if (!includeEntryPoint) + m->getEntryPoints().clear(); + } + return SLANG_OK; +} + +SlangResult _addLibraryReference( + EndToEndCompileRequest* req, + String path, + IArtifact* artifact, + bool includeEntryPoint) +{ + auto desc = artifact->getDesc(); + + // TODO(JS): + // This isn't perhaps the best way to handle this scenario, as IArtifact can + // support lazy evaluation, with suitable hander. + // For now we just read in and strip out the bits we want. + if (isDerivedFrom(desc.kind, ArtifactKind::Container) && + isDerivedFrom(desc.payload, ArtifactPayload::CompileResults)) + { + // We want to read as a file system + ComPtr<IArtifact> container; + + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::readContainer(artifact, container)); + + // Find the payload... It should be linkable + if (!ArtifactDescUtil::isLinkable(container->getDesc())) + { + return SLANG_FAIL; + } + + ComPtr<IModuleLibrary> libraryIntf; + SLANG_RETURN_ON_FAIL( + loadModuleLibrary(ArtifactKeep::Yes, container, path, req, libraryIntf)); + + auto library = as<ModuleLibrary>(libraryIntf); + + // Look for source maps + for (auto associated : container->getAssociated()) + { + auto assocDesc = associated->getDesc(); + + // If we find an obfuscated source map load it and associate + if (isDerivedFrom(assocDesc.kind, ArtifactKind::Json) && + isDerivedFrom(assocDesc.payload, ArtifactPayload::SourceMap) && + isDerivedFrom(assocDesc.style, ArtifactStyle::Obfuscated)) + { + ComPtr<ICastable> castable; + SLANG_RETURN_ON_FAIL(associated->getOrCreateRepresentation( + SourceMap::getTypeGuid(), + ArtifactKeep::Yes, + castable.writeRef())); + auto sourceMapBox = asBoxValue<SourceMap>(castable); + SLANG_ASSERT(sourceMapBox); + + // TODO(JS): + // There is perhaps (?) a risk here that we might copy the obfuscated map + // into some output container. Currently that only happens for source maps + // that are from translation units. + // + // On the other hand using "import" is a way that such source maps *would* be + // copied into the output, and that is something that could be a vector + // for leaking. + // + // That isn't a risk from -r though because, it doesn't create a translation + // unit(s). + for (auto module : library->m_modules) + { + module->getIRModule()->setObfuscatedSourceMap(sourceMapBox); + } + + // Look up the source file + auto sourceManager = req->getSink()->getSourceManager(); + + auto name = Path::getFileNameWithoutExt(associated->getName()); + + if (name.getLength()) + { + // Note(tfoley): There is a subtle requirement here, that any + // source file `name` that might be searched for here *must* + // have been added to the `sourceManager` already, as a + // byproduct of debug source location information getting + // deserialized as part of the call to `loadModuleLibrary()` above. + // + // The implicit dependency is frustrating, and could potentially + // break if somehow the debug info chunk was stripped from a binary, + // while the source map was left in (which should be valid, even if + // it is unlikely to be what a user wants). + // + // Ideally the source map would either be made an integral part of + // the debug source location chunk, so they are loaded together, + // or the `SourceManager` would be adapted so that it can store + // registered source maps independent of whether or not the + // corresponding source file(s) have been loaded. + + auto sourceFile = sourceManager->findSourceFileByPathRecursively(name); + SLANG_ASSERT(sourceFile); + sourceFile->setSourceMap(sourceMapBox, SourceMapKind::Obfuscated); + } + } + } + + SLANG_RETURN_ON_FAIL(_addLibraryReference(req, library, includeEntryPoint)); + return SLANG_OK; + } + + if (desc.kind == ArtifactKind::Library && desc.payload == ArtifactPayload::SlangIR) + { + ComPtr<IModuleLibrary> libraryIntf; + + SLANG_RETURN_ON_FAIL( + loadModuleLibrary(ArtifactKeep::Yes, artifact, path, req, libraryIntf)); + + auto library = as<ModuleLibrary>(libraryIntf); + if (!library) + { + return SLANG_FAIL; + } + + SLANG_RETURN_ON_FAIL(_addLibraryReference(req, library, includeEntryPoint)); + return SLANG_OK; + } + + // TODO(JS): + // Do we want to check the path exists? + + // Add to the m_libModules + auto linkage = req->getLinkage(); + linkage->m_libModules.add(ComPtr<IArtifact>(artifact)); + + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::addLibraryReference( + const char* basePath, + const void* libData, + size_t libDataSize) +{ + // We need to deserialize and add the modules + ComPtr<IModuleLibrary> library; + + auto libBlob = RawBlob::create((const Byte*)libData, libDataSize); + + SLANG_RETURN_ON_FAIL( + loadModuleLibrary(libBlob, (const Byte*)libData, libDataSize, basePath, this, library)); + + // Create an artifact without any name (as one is not provided) + auto artifact = + Artifact::create(ArtifactDesc::make(ArtifactKind::Library, ArtifactPayload::SlangIR)); + artifact->addRepresentation(library); + + return _addLibraryReference(this, basePath, artifact, true); +} + +void EndToEndCompileRequest::addTranslationUnitPreprocessorDefine( + int translationUnitIndex, + const char* key, + const char* value) +{ + getFrontEndReq()->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; +} + +void EndToEndCompileRequest::addTranslationUnitSourceFile( + int translationUnitIndex, + char const* path) +{ + auto frontEndReq = getFrontEndReq(); + if (!path) + return; + if (translationUnitIndex < 0) + return; + if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) + return; + + frontEndReq->addTranslationUnitSourceFile(translationUnitIndex, path); +} + +void EndToEndCompileRequest::addTranslationUnitSourceString( + int translationUnitIndex, + char const* path, + char const* source) +{ + if (!source) + return; + addTranslationUnitSourceStringSpan(translationUnitIndex, path, source, source + strlen(source)); +} + +void EndToEndCompileRequest::addTranslationUnitSourceStringSpan( + int translationUnitIndex, + char const* path, + char const* sourceBegin, + char const* sourceEnd) +{ + auto frontEndReq = getFrontEndReq(); + if (!sourceBegin) + return; + if (translationUnitIndex < 0) + return; + if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) + return; + + if (!path) + path = ""; + + const auto slice = UnownedStringSlice(sourceBegin, sourceEnd); + + auto blob = RawBlob::create(slice.begin(), slice.getLength()); + + frontEndReq->addTranslationUnitSourceBlob(translationUnitIndex, path, blob); +} + +void EndToEndCompileRequest::addTranslationUnitSourceBlob( + int translationUnitIndex, + char const* path, + ISlangBlob* sourceBlob) +{ + auto frontEndReq = getFrontEndReq(); + if (!sourceBlob) + return; + if (translationUnitIndex < 0) + return; + if (Slang::Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) + return; + + if (!path) + path = ""; + + frontEndReq->addTranslationUnitSourceBlob(translationUnitIndex, path, sourceBlob); +} + + +int EndToEndCompileRequest::addEntryPoint( + int translationUnitIndex, + char const* name, + SlangStage stage) +{ + return addEntryPointEx(translationUnitIndex, name, stage, 0, nullptr); +} + +int EndToEndCompileRequest::addEntryPointEx( + int translationUnitIndex, + char const* name, + SlangStage stage, + int genericParamTypeNameCount, + char const** genericParamTypeNames) +{ + auto frontEndReq = getFrontEndReq(); + if (!name) + return -1; + if (translationUnitIndex < 0) + return -1; + if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) + return -1; + + List<String> typeNames; + for (int i = 0; i < genericParamTypeNameCount; i++) + typeNames.add(genericParamTypeNames[i]); + + return addEntryPoint(translationUnitIndex, name, Profile(Stage(stage)), typeNames); +} + +SlangResult EndToEndCompileRequest::setGlobalGenericArgs( + int genericArgCount, + char const** genericArgs) +{ + auto& argStrings = m_globalSpecializationArgStrings; + argStrings.clear(); + for (int i = 0; i < genericArgCount; i++) + argStrings.add(genericArgs[i]); + + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::setTypeNameForGlobalExistentialTypeParam( + int slotIndex, + char const* typeName) +{ + if (slotIndex < 0) + return SLANG_FAIL; + if (!typeName) + return SLANG_FAIL; + + auto& typeArgStrings = m_globalSpecializationArgStrings; + if (Index(slotIndex) >= typeArgStrings.getCount()) + typeArgStrings.setCount(slotIndex + 1); + typeArgStrings[slotIndex] = String(typeName); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::setTypeNameForEntryPointExistentialTypeParam( + int entryPointIndex, + int slotIndex, + char const* typeName) +{ + if (entryPointIndex < 0) + return SLANG_FAIL; + if (slotIndex < 0) + return SLANG_FAIL; + if (!typeName) + return SLANG_FAIL; + + if (Index(entryPointIndex) >= m_entryPoints.getCount()) + return SLANG_FAIL; + + auto& entryPointInfo = m_entryPoints[entryPointIndex]; + auto& typeArgStrings = entryPointInfo.specializationArgStrings; + if (Index(slotIndex) >= typeArgStrings.getCount()) + typeArgStrings.setCount(slotIndex + 1); + typeArgStrings[slotIndex] = String(typeName); + return SLANG_OK; +} + +void EndToEndCompileRequest::setAllowGLSLInput(bool value) +{ + getOptionSet().set(CompilerOptionName::AllowGLSL, value); +} + +SlangResult EndToEndCompileRequest::compile() +{ + SlangResult res = SLANG_FAIL; + double downstreamStartTime = 0.0; + double totalStartTime = 0.0; + + if (getOptionSet().getBoolOption(CompilerOptionName::ReportDownstreamTime)) + { + getSession()->getCompilerElapsedTime(&totalStartTime, &downstreamStartTime); + PerformanceProfiler::getProfiler()->clear(); + } +#if !defined(SLANG_DEBUG_INTERNAL_ERROR) + // By default we'd like to catch as many internal errors as possible, + // and report them to the user nicely (rather than just crash their + // application). Internally Slang currently uses exceptions for this. + // + // TODO: Consider using `setjmp()`-style escape so that we can work + // with applications that disable exceptions. + // + // TODO: Consider supporting Windows "Structured Exception Handling" + // so that we can also recover from a wider class of crashes. + + try + { + SLANG_PROFILE_SECTION(compileInner); + res = executeActions(); + } + catch (const AbortCompilationException& e) + { + // This situation indicates a fatal (but not necessarily internal) error + // that forced compilation to terminate. There should already have been + // a diagnostic produced, so we don't need to add one here. + if (getSink()->getErrorCount() == 0) + { + // If for some reason we didn't output any diagnostic, something is + // going wrong, but we want to make sure we at least output something. + getSink()->diagnose( + SourceLoc(), + Diagnostics::compilationAbortedDueToException, + typeid(e).name(), + e.Message); + } + } + catch (const Exception& e) + { + // The compiler failed due to an internal error that was detected. + // We will print out information on the exception to help out the user + // in either filing a bug, or locating what in their code created + // a problem. + getSink()->diagnose( + SourceLoc(), + Diagnostics::compilationAbortedDueToException, + typeid(e).name(), + e.Message); + } + catch (...) + { + // The compiler failed due to some exception that wasn't a sublass of + // `Exception`, so something really fishy is going on. We want to + // let the user know that we messed up, so they know to blame Slang + // and not some other component in their system. + getSink()->diagnose(SourceLoc(), Diagnostics::compilationAborted); + } + m_diagnosticOutput = getSink()->outputBuffer.produceString(); + +#else + // When debugging, we probably don't want to filter out any errors, since + // we are probably trying to root-cause and *fix* those errors. + { + res = req->executeActions(); + } +#endif + + if (getOptionSet().getBoolOption(CompilerOptionName::ReportDownstreamTime)) + { + double downstreamEndTime = 0; + double totalEndTime = 0; + getSession()->getCompilerElapsedTime(&totalEndTime, &downstreamEndTime); + double downstreamTime = downstreamEndTime - downstreamStartTime; + String downstreamTimeStr = String(downstreamTime, "%.2f"); + getSink()->diagnose(SourceLoc(), Diagnostics::downstreamCompileTime, downstreamTimeStr); + } + if (getOptionSet().getBoolOption(CompilerOptionName::ReportPerfBenchmark)) + { + StringBuilder perfResult; + PerformanceProfiler::getProfiler()->getResult(perfResult); + perfResult << "\nType Dictionary Size: " << getSession()->m_typeDictionarySize << "\n"; + getSink()->diagnose( + SourceLoc(), + Diagnostics::performanceBenchmarkResult, + perfResult.produceString()); + } + + // Repro dump handling + { + auto dumpRepro = getOptionSet().getStringOption(CompilerOptionName::DumpRepro); + auto dumpReproOnError = getOptionSet().getBoolOption(CompilerOptionName::DumpReproOnError); + + if (dumpRepro.getLength()) + { + SlangResult saveRes = ReproUtil::saveState(this, dumpRepro); + if (SLANG_FAILED(saveRes)) + { + getSink()->diagnose(SourceLoc(), Diagnostics::unableToWriteReproFile, dumpRepro); + return saveRes; + } + } + else if (dumpReproOnError && SLANG_FAILED(res)) + { + String reproFileName; + SlangResult saveRes = SLANG_FAIL; + + RefPtr<Stream> stream; + if (SLANG_SUCCEEDED(ReproUtil::findUniqueReproDumpStream(this, reproFileName, stream))) + { + saveRes = ReproUtil::saveState(this, stream); + } + + if (SLANG_FAILED(saveRes)) + { + getSink()->diagnose( + SourceLoc(), + Diagnostics::unableToWriteReproFile, + reproFileName); + } + } + } + + auto reflectionPath = getOptionSet().getStringOption(CompilerOptionName::EmitReflectionJSON); + if (reflectionPath.getLength() != 0) + { + auto bufferWriter = PrettyWriter(); + emitReflectionJSON(this, this->getReflection(), bufferWriter); + if (reflectionPath == "-") + { + auto builder = bufferWriter.getBuilder(); + StdWriters::getOut().write(builder.getBuffer(), builder.getLength()); + } + else if (SLANG_FAILED(File::writeAllText(reflectionPath, bufferWriter.getBuilder()))) + { + getSink()->diagnose(SourceLoc(), Diagnostics::unableToWriteFile, reflectionPath); + } + } + + return res; +} + +int EndToEndCompileRequest::getDependencyFileCount() +{ + auto frontEndReq = getFrontEndReq(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); + return (int)program->getFileDependencies().getCount(); +} + +char const* EndToEndCompileRequest::getDependencyFilePath(int index) +{ + auto frontEndReq = getFrontEndReq(); + auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); + SourceFile* sourceFile = program->getFileDependencies()[index]; + return sourceFile->getPathInfo().hasFoundPath() + ? sourceFile->getPathInfo().getMostUniqueIdentity().getBuffer() + : "unknown"; +} + +int EndToEndCompileRequest::getTranslationUnitCount() +{ + return (int)getFrontEndReq()->translationUnits.getCount(); +} + +void const* EndToEndCompileRequest::getEntryPointCode(int entryPointIndex, size_t* outSize) +{ + // Zero the size initially, in case need to return nullptr for error. + if (outSize) + { + *outSize = 0; + } + + auto linkage = getLinkage(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + + // TODO: We should really accept a target index in this API + Index targetIndex = 0; + auto targetCount = linkage->targets.getCount(); + if (targetIndex >= targetCount) + return nullptr; + auto targetReq = linkage->targets[targetIndex]; + + + if (entryPointIndex < 0) + return nullptr; + if (Index(entryPointIndex) >= program->getEntryPointCount()) + return nullptr; + auto entryPoint = program->getEntryPoint(entryPointIndex); + + auto targetProgram = program->getTargetProgram(targetReq); + if (!targetProgram) + return nullptr; + IArtifact* artifact = targetProgram->getExistingEntryPointResult(entryPointIndex); + if (!artifact) + { + return nullptr; + } + + ComPtr<ISlangBlob> blob; + SLANG_RETURN_NULL_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, blob.writeRef())); + + if (outSize) + { + *outSize = blob->getBufferSize(); + } + + return (void*)blob->getBufferPointer(); +} + +SlangResult EndToEndCompileRequest::getCompileTimeProfile( + ISlangProfiler** compileTimeProfile, + bool shouldClear) +{ + if (compileTimeProfile == nullptr) + { + return SLANG_E_INVALID_ARG; + } + + SlangProfiler* profiler = new SlangProfiler(PerformanceProfiler::getProfiler()); + + if (shouldClear) + { + PerformanceProfiler::getProfiler()->clear(); + } + + ComPtr<ISlangProfiler> result(profiler); + *compileTimeProfile = result.detach(); + return SLANG_OK; +} + +static SlangResult _getEntryPointResult( + EndToEndCompileRequest* req, + int entryPointIndex, + int targetIndex, + ComPtr<IArtifact>& outArtifact) +{ + auto linkage = req->getLinkage(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); + + Index targetCount = linkage->targets.getCount(); + if ((targetIndex < 0) || (targetIndex >= targetCount)) + { + return SLANG_E_INVALID_ARG; + } + auto targetReq = linkage->targets[targetIndex]; + + // Get the entry point count on the program, rather than (say) req->m_entryPoints.getCount() + // because + // 1) The entry point is fetched from the program anyway so must be consistent + // 2) The req may not have all entry points (for example when an entry point is in a module) + const Index entryPointCount = program->getEntryPointCount(); + + if ((entryPointIndex < 0) || (entryPointIndex >= entryPointCount)) + { + return SLANG_E_INVALID_ARG; + } + auto entryPointReq = program->getEntryPoint(entryPointIndex); + + auto targetProgram = program->getTargetProgram(targetReq); + if (!targetProgram) + return SLANG_FAIL; + + outArtifact = targetProgram->getExistingEntryPointResult(entryPointIndex); + return SLANG_OK; +} + +static SlangResult _getWholeProgramResult( + EndToEndCompileRequest* req, + int targetIndex, + ComPtr<IArtifact>& outArtifact) +{ + auto linkage = req->getLinkage(); + auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); + + if (!program) + { + return SLANG_FAIL; + } + + Index targetCount = linkage->targets.getCount(); + if ((targetIndex < 0) || (targetIndex >= targetCount)) + { + return SLANG_E_INVALID_ARG; + } + auto targetReq = linkage->targets[targetIndex]; + + auto targetProgram = program->getTargetProgram(targetReq); + if (!targetProgram) + return SLANG_FAIL; + outArtifact = targetProgram->getExistingWholeProgramResult(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getEntryPointCodeBlob( + int entryPointIndex, + int targetIndex, + ISlangBlob** outBlob) +{ + if (!outBlob) + return SLANG_E_INVALID_ARG; + ComPtr<IArtifact> artifact; + SLANG_RETURN_ON_FAIL(_getEntryPointResult(this, entryPointIndex, targetIndex, artifact)); + SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, outBlob)); + + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary) +{ + if (!outSharedLibrary) + return SLANG_E_INVALID_ARG; + ComPtr<IArtifact> artifact; + SLANG_RETURN_ON_FAIL(_getEntryPointResult(this, entryPointIndex, targetIndex, artifact)); + SLANG_RETURN_ON_FAIL(artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary)); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) +{ + if (!outBlob) + return SLANG_E_INVALID_ARG; + + ComPtr<IArtifact> artifact; + SLANG_RETURN_ON_FAIL(_getWholeProgramResult(this, targetIndex, artifact)); + SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, outBlob)); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getTargetHostCallable( + int targetIndex, + ISlangSharedLibrary** outSharedLibrary) +{ + if (!outSharedLibrary) + return SLANG_E_INVALID_ARG; + + ComPtr<IArtifact> artifact; + SLANG_RETURN_ON_FAIL(_getWholeProgramResult(this, targetIndex, artifact)); + SLANG_RETURN_ON_FAIL(artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary)); + return SLANG_OK; +} + +char const* EndToEndCompileRequest::getEntryPointSource(int entryPointIndex) +{ + return (char const*)getEntryPointCode(entryPointIndex, nullptr); +} + +ISlangMutableFileSystem* EndToEndCompileRequest::getCompileRequestResultAsFileSystem() +{ + if (!m_containerFileSystem) + { + if (m_containerArtifact) + { + ComPtr<ISlangMutableFileSystem> fileSystem(new MemoryFileSystem); + + // Filter the containerArtifact into things that can be written + ComPtr<IArtifact> writeArtifact; + if (SLANG_SUCCEEDED( + ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)) && + writeArtifact) + { + if (SLANG_SUCCEEDED( + ArtifactContainerUtil::writeContainer(writeArtifact, "", fileSystem))) + { + m_containerFileSystem.swap(fileSystem); + } + } + } + } + + return m_containerFileSystem; +} + +void const* EndToEndCompileRequest::getCompileRequestCode(size_t* outSize) +{ + if (m_containerArtifact) + { + ComPtr<ISlangBlob> containerBlob; + if (SLANG_SUCCEEDED( + m_containerArtifact->loadBlob(ArtifactKeep::Yes, containerBlob.writeRef()))) + { + *outSize = containerBlob->getBufferSize(); + return containerBlob->getBufferPointer(); + } + } + + // Container blob does not have any contents + *outSize = 0; + return nullptr; +} + +SlangResult EndToEndCompileRequest::getContainerCode(ISlangBlob** outBlob) +{ + if (m_containerArtifact) + { + ComPtr<ISlangBlob> containerBlob; + if (SLANG_SUCCEEDED( + m_containerArtifact->loadBlob(ArtifactKeep::Yes, containerBlob.writeRef()))) + { + *outBlob = containerBlob.detach(); + return SLANG_OK; + } + } + return SLANG_FAIL; +} + +SlangResult EndToEndCompileRequest::loadRepro( + ISlangFileSystem* fileSystem, + const void* data, + size_t size) +{ + List<uint8_t> buffer; + SLANG_RETURN_ON_FAIL(ReproUtil::loadState((const uint8_t*)data, size, getSink(), buffer)); + + MemoryOffsetBase base; + base.set(buffer.getBuffer(), buffer.getCount()); + + ReproUtil::RequestState* requestState = ReproUtil::getRequest(buffer); + + SLANG_RETURN_ON_FAIL(ReproUtil::load(base, requestState, fileSystem, this)); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::saveRepro(ISlangBlob** outBlob) +{ + OwnedMemoryStream stream(FileAccess::Write); + + SLANG_RETURN_ON_FAIL(ReproUtil::saveState(this, &stream)); + + // Put the content of the stream in the blob + + List<uint8_t> data; + stream.swapContents(data); + + *outBlob = ListBlob::moveCreate(data).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::enableReproCapture() +{ + getLinkage()->setRequireCacheFileSystem(true); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::processCommandLineArguments( + char const* const* args, + int argCount) +{ + return parseOptions(this, argCount, args); +} + +SlangReflection* EndToEndCompileRequest::getReflection() +{ + auto linkage = getLinkage(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + + // Note(tfoley): The API signature doesn't let the client + // specify which target they want to access reflection + // information for, so for now we default to the first one. + // + // TODO: Add a new `spGetReflectionForTarget(req, targetIndex)` + // so that we can do this better, and make it clear that + // `spGetReflection()` is shorthand for `targetIndex == 0`. + // + Slang::Index targetIndex = 0; + auto targetCount = linkage->targets.getCount(); + if (targetIndex >= targetCount) + return nullptr; + + auto targetReq = linkage->targets[targetIndex]; + auto targetProgram = program->getTargetProgram(targetReq); + + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + auto programLayout = targetProgram->getOrCreateLayout(&sink); + + return (SlangReflection*)programLayout; +} + +SlangResult EndToEndCompileRequest::getProgram(slang::IComponentType** outProgram) +{ + auto program = getSpecializedGlobalComponentType(); + *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getProgramWithEntryPoints(slang::IComponentType** outProgram) +{ + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getModule( + SlangInt translationUnitIndex, + slang::IModule** outModule) +{ + auto module = getFrontEndReq()->getTranslationUnit(translationUnitIndex)->getModule(); + + *outModule = Slang::ComPtr<slang::IModule>(module).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getSession(slang::ISession** outSession) +{ + auto session = getLinkage(); + *outSession = Slang::ComPtr<slang::ISession>(session).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::getEntryPoint( + SlangInt entryPointIndex, + slang::IComponentType** outEntryPoint) +{ + auto entryPoint = getSpecializedEntryPointComponentType(entryPointIndex); + *outEntryPoint = Slang::ComPtr<slang::IComponentType>(entryPoint).detach(); + return SLANG_OK; +} + +SlangResult EndToEndCompileRequest::isParameterLocationUsed( + Int entryPointIndex, + Int targetIndex, + SlangParameterCategory category, + UInt spaceIndex, + UInt registerIndex, + bool& outUsed) +{ + if (!ShaderBindingRange::isUsageTracked((slang::ParameterCategory)category)) + return SLANG_E_NOT_AVAILABLE; + + ComPtr<IArtifact> artifact; + if (SLANG_FAILED(_getEntryPointResult( + this, + static_cast<int>(entryPointIndex), + static_cast<int>(targetIndex), + artifact))) + return SLANG_E_INVALID_ARG; + + if (!artifact) + return SLANG_E_NOT_AVAILABLE; + + // Find a rep + auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); + if (!metadata) + return SLANG_E_NOT_AVAILABLE; + + return metadata->isParameterLocationUsed(category, spaceIndex, registerIndex, outUsed); +} + +} // namespace Slang diff --git a/source/slang/slang-end-to-end-request.h b/source/slang/slang-end-to-end-request.h new file mode 100644 index 000000000..af04eac63 --- /dev/null +++ b/source/slang/slang-end-to-end-request.h @@ -0,0 +1,416 @@ +// slang-end-to-end-request.h +#pragma once + +// +// This file provides the `EndToEndCompileRequest` type and +// related utilities. +// +// The primary purpose of `EndToEndCompileRequest` is to +// implement the overall flow of compilation for the +// `slangc` command-line tool. Command-line compiles need +// to deal with various details that aren't relevant to +// API-based compiles (e.g., writing output to files), +// and also need to implement a lot of complicated +// "do what I mean" policy that is expected by users of +// `slangc` but that can be dangerously implicit when +// that policy is enshrined in an API. +// +// In addition to serving the needs of `slangc`, the +// `EndToEndCompileRequest` type also implements the +// deprecated `slang::ICompileRequest` interface from +// the public API. +// + +#include "../compiler-core/slang-source-embed-util.h" +#include "../core/slang-file-system.h" +#include "slang-compile-request.h" + +namespace Slang +{ + +enum class ContainerFormat : SlangContainerFormatIntegral +{ + None = SLANG_CONTAINER_FORMAT_NONE, + SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, +}; + +// TODO: everything related to `StdWriters` should be removed. +enum class WriterChannel : SlangWriterChannelIntegral +{ + Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, + StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, + StdError = SLANG_WRITER_CHANNEL_STD_ERROR, + CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, +}; + +// TODO: everything related to `StdWriters` should be removed. +enum class WriterMode : SlangWriterModeIntegral +{ + Text = SLANG_WRITER_MODE_TEXT, + Binary = SLANG_WRITER_MODE_BINARY, +}; + +/// A compile request that spans the front and back ends of the compiler +/// +/// This is what the command-line `slangc` uses, as well as the legacy +/// C API. It ties together the functionality of `Linkage`, +/// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small +/// number of additional features that primarily make sense for +/// command-line usage. +/// +class EndToEndCompileRequest : public RefObject, public slang::ICompileRequest +{ +public: + SLANG_CLASS_GUID(0xce6d2383, 0xee1b, 0x4fd7, {0xa0, 0xf, 0xb8, 0xb6, 0x33, 0x12, 0x95, 0xc8}) + + // ISlangUnknown + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) + SLANG_OVERRIDE; + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE + + // slang::ICompileRequest + virtual SLANG_NO_THROW void SLANG_MCALL setFileSystem(ISlangFileSystem* fileSystem) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCompileFlags(SlangCompileFlags flags) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangCompileFlags SLANG_MCALL getCompileFlags() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediates(int enable) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediatePrefix(const char* prefix) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setEnableEffectAnnotations(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setLineDirectiveMode(SlangLineDirectiveMode mode) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCodeGenTarget(SlangCompileTarget target) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL addCodeGenTarget(SlangCompileTarget target) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetProfile(int targetIndex, SlangProfileID profile) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetFlags(int targetIndex, SlangTargetFlags flags) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetFloatingPointMode(int targetIndex, SlangFloatingPointMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetForceDXLayout(int targetIndex, bool value) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetGenerateWholeProgram(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetEmbedDownstreamIR(int targetIndex, bool value) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setMatrixLayoutMode(SlangMatrixLayoutMode mode) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoLevel(SlangDebugInfoLevel level) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setOptimizationLevel(SlangOptimizationLevel level) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setOutputContainerFormat(SlangContainerFormat format) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setPassThrough(SlangPassThrough passThrough) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setDiagnosticCallback(SlangDiagnosticCallback callback, void const* userData) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setWriter(SlangWriterChannel channel, ISlangWriter* writer) SLANG_OVERRIDE; + virtual SLANG_NO_THROW ISlangWriter* SLANG_MCALL getWriter(SlangWriterChannel channel) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addSearchPath(const char* searchDir) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + addPreprocessorDefine(const char* key, const char* value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + processCommandLineArguments(char const* const* args, int argCount) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL + addTranslationUnit(SlangSourceLanguage language, char const* name) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDefaultModuleName(const char* defaultModuleName) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitPreprocessorDefine( + int translationUnitIndex, + const char* key, + const char* value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + addTranslationUnitSourceFile(int translationUnitIndex, char const* path) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceString( + int translationUnitIndex, + char const* path, + char const* source) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL addLibraryReference( + const char* basePath, + const void* libData, + size_t libDataSize) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceStringSpan( + int translationUnitIndex, + char const* path, + char const* sourceBegin, + char const* sourceEnd) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceBlob( + int translationUnitIndex, + char const* path, + ISlangBlob* sourceBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL + addEntryPoint(int translationUnitIndex, char const* name, SlangStage stage) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL addEntryPointEx( + int translationUnitIndex, + char const* name, + SlangStage stage, + int genericArgCount, + char const** genericArgs) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setGlobalGenericArgs(int genericArgCount, char const** genericArgs) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setTypeNameForGlobalExistentialTypeParam(int slotIndex, char const* typeName) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL setTypeNameForEntryPointExistentialTypeParam( + int entryPointIndex, + int slotIndex, + char const* typeName) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setAllowGLSLInput(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL compile() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getDiagnosticOutput() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getDiagnosticOutputBlob(ISlangBlob** outBlob) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL getDependencyFileCount() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(int index) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL getTranslationUnitCount() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getEntryPointSource(int entryPointIndex) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void const* SLANG_MCALL + getEntryPointCode(int entryPointIndex, size_t* outSize) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCodeBlob( + int entryPointIndex, + int targetIndex, + ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getTargetHostCallable(int targetIndex, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void const* SLANG_MCALL getCompileRequestCode(size_t* outSize) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW ISlangMutableFileSystem* SLANG_MCALL + getCompileRequestResultAsFileSystem() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getContainerCode(ISlangBlob** outBlob) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + loadRepro(ISlangFileSystem* fileSystem, const void* data, size_t size) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL saveRepro(ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL enableReproCapture() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getProgram(slang::IComponentType** outProgram) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getEntryPoint(SlangInt entryPointIndex, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getModule(SlangInt translationUnitIndex, slang::IModule** outModule) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getSession(slang::ISession** outSession) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangReflection* SLANG_MCALL getReflection() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCommandLineCompilerMode() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + addTargetCapability(SlangInt targetIndex, SlangCapabilityID capability) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getProgramWithEntryPoints(slang::IComponentType** outProgram) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL isParameterLocationUsed( + SlangInt entryPointIndex, + SlangInt targetIndex, + SlangParameterCategory category, + SlangUInt spaceIndex, + SlangUInt registerIndex, + bool& outUsed) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetLineDirectiveMode(SlangInt targetIndex, SlangLineDirectiveMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + overrideDiagnosticSeverity(SlangInt messageID, SlangSeverity overrideSeverity) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangDiagnosticFlags SLANG_MCALL getDiagnosticFlags() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDiagnosticFlags(SlangDiagnosticFlags flags) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setSkipSPIRVValidation(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetUseMinimumSlangOptimization(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setIgnoreCapabilityCheck(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getCompileTimeProfile(ISlangProfiler** compileTimeProfile, bool isClear) SLANG_OVERRIDE; + + void setTrackLiveness(bool v); + + EndToEndCompileRequest(Session* session); + + EndToEndCompileRequest(Linkage* linkage); + + ~EndToEndCompileRequest(); + + // If enabled will emit IR + bool m_emitIr = false; + + // What container format are we being asked to generate? + // If it's set to a format, the container blob will be calculated during compile + ContainerFormat m_containerFormat = ContainerFormat::None; + + /// Where the container is stored. This is calculated as part of compile if m_containerFormat is + /// set to a supported format. + ComPtr<IArtifact> m_containerArtifact; + /// Holds the container as a file system + ComPtr<ISlangMutableFileSystem> m_containerFileSystem; + + /// File system used by repro system if a file couldn't be found within the repro (or associated + /// directory) + ComPtr<ISlangFileSystem> m_reproFallbackFileSystem = + ComPtr<ISlangFileSystem>(OSFileSystem::getExtSingleton()); + + // Path to output container to + String m_containerOutputPath; + + // Should we just pass the input to another compiler? + PassThroughMode m_passThrough = PassThroughMode::None; + + /// If output should be source embedded, define the style of the embedding + SourceEmbedUtil::Style m_sourceEmbedStyle = SourceEmbedUtil::Style::None; + /// The language to be used for source embedding + SourceLanguage m_sourceEmbedLanguage = SourceLanguage::C; + /// Source embed variable name. Note may be used as a basis for names if multiple items written + String m_sourceEmbedName; + + /// Source code for the specialization arguments to use for the global specialization parameters + /// of the program. + List<String> m_globalSpecializationArgStrings; + + // Are we being driven by the command-line `slangc`, and should act accordingly? + bool m_isCommandLineCompile = false; + + String m_diagnosticOutput; + + /// A blob holding the diagnostic output + ComPtr<ISlangBlob> m_diagnosticOutputBlob; + + /// Per-entry-point information not tracked by other compile requests + class EntryPointInfo : public RefObject + { + public: + /// Source code for the specialization arguments to use for the specialization parameters of + /// the entry point. + List<String> specializationArgStrings; + }; + List<EntryPointInfo> m_entryPoints; + + /// Per-target information only needed for command-line compiles + class TargetInfo : public RefObject + { + public: + // Requested output paths for each entry point. + // An empty string indices no output desired for + // the given entry point. + Dictionary<Int, String> entryPointOutputPaths; + String wholeTargetOutputPath; + CompilerOptionSet targetOptions; + }; + Dictionary<TargetRequest*, RefPtr<TargetInfo>> m_targetInfos; + + CompilerOptionSet m_optionSetForDefaultTarget; + + CompilerOptionSet& getTargetOptionSet(TargetRequest* req); + + CompilerOptionSet& getTargetOptionSet(Index targetIndex); + + String m_dependencyOutputPath; + + /// Writes the modules in a container to the stream + SlangResult writeContainerToStream(Stream* stream); + + /// If a container format has been specified produce a container (stored in m_containerBlob) + SlangResult maybeCreateContainer(); + /// If a container has been constructed and the filename/path has contents will try to write + /// the container contents to the file + SlangResult maybeWriteContainer(const String& fileName); + + Linkage* getLinkage() { return m_linkage; } + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile, + List<String> const& genericTypeNames); + + void setWriter(WriterChannel chan, ISlangWriter* writer); + ISlangWriter* getWriter(WriterChannel chan) const + { + return m_writers->getWriter(SlangWriterChannel(chan)); + } + + /// The end to end request can be passed as nullptr, if not driven by one + SlangResult executeActionsInner(); + SlangResult executeActions(); + + Session* getSession() { return m_session; } + DiagnosticSink* getSink() { return &m_sink; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + + FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } + + ComponentType* getUnspecializedGlobalComponentType() + { + return getFrontEndReq()->getGlobalComponentType(); + } + ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() + { + return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); + } + + ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } + ComponentType* getSpecializedGlobalAndEntryPointsComponentType() + { + return m_specializedGlobalAndEntryPointsComponentType; + } + + ComponentType* getSpecializedEntryPointComponentType(Index index) + { + return m_specializedEntryPoints[index]; + } + + void writeArtifactToStandardOutput(IArtifact* artifact, DiagnosticSink* sink); + + void generateOutput(); + + CompilerOptionSet& getOptionSet() { return m_linkage->m_optionSet; } + +private: + String _getWholeProgramPath(TargetRequest* targetReq); + String _getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex); + + /// Maybe write the artifact to the path (if set), or stdout (if there is no container or path) + SlangResult _maybeWriteArtifact(const String& path, IArtifact* artifact); + SlangResult _maybeWriteDebugArtifact( + TargetProgram* targetProgram, + const String& path, + IArtifact* artifact); + SlangResult _writeArtifact(const String& path, IArtifact* artifact); + + /// Adds any extra settings to complete a targetRequest + void _completeTargetRequest(UInt targetIndex); + + ISlangUnknown* getInterface(const Guid& guid); + + void generateOutput(ComponentType* program); + void generateOutput(TargetProgram* targetProgram); + + void init(); + + Session* m_session = nullptr; + RefPtr<Linkage> m_linkage; + DiagnosticSink m_sink; + RefPtr<FrontEndCompileRequest> m_frontEndReq; + RefPtr<ComponentType> m_specializedGlobalComponentType; + RefPtr<ComponentType> m_specializedGlobalAndEntryPointsComponentType; + List<RefPtr<ComponentType>> m_specializedEntryPoints; + + // For output + + RefPtr<StdWriters> m_writers; +}; + +} // namespace Slang diff --git a/source/slang/slang-entry-point.cpp b/source/slang/slang-entry-point.cpp new file mode 100644 index 000000000..0cce90646 --- /dev/null +++ b/source/slang/slang-entry-point.cpp @@ -0,0 +1,159 @@ +// slang-entry-point.cpp +#include "slang-entry-point.h" + +#include "slang-compiler.h" +#include "slang-mangle.h" + +namespace Slang +{ + +// +// EntryPoint +// + +ISlangUnknown* EntryPoint::getInterface(const Guid& guid) +{ + if (guid == slang::IEntryPoint::getTypeGuid()) + return static_cast<slang::IEntryPoint*>(this); + + return Super::getInterface(guid); +} + +RefPtr<EntryPoint> EntryPoint::create( + Linkage* linkage, + DeclRef<FuncDecl> funcDeclRef, + Profile profile) +{ + RefPtr<EntryPoint> entryPoint = + new EntryPoint(linkage, funcDeclRef.getName(), profile, funcDeclRef); + entryPoint->m_mangledName = getMangledName(linkage->getASTBuilder(), funcDeclRef); + return entryPoint; +} + +RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough( + Linkage* linkage, + Name* name, + Profile profile) +{ + RefPtr<EntryPoint> entryPoint = new EntryPoint(linkage, name, profile, DeclRef<FuncDecl>()); + return entryPoint; +} + +RefPtr<EntryPoint> EntryPoint::createDummyForDeserialize( + Linkage* linkage, + Name* name, + Profile profile, + String mangledName) +{ + RefPtr<EntryPoint> entryPoint = new EntryPoint(linkage, name, profile, DeclRef<FuncDecl>()); + entryPoint->m_mangledName = mangledName; + return entryPoint; +} + +EntryPoint::EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef) + : ComponentType(linkage), m_name(name), m_profile(profile), m_funcDeclRef(funcDeclRef) +{ + // Collect any specialization parameters used by the entry point + // + _collectShaderParams(); +} + +Module* EntryPoint::getModule() +{ + return Slang::getModule(getFuncDecl()); +} + +Index EntryPoint::getSpecializationParamCount() +{ + return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); +} + +SpecializationParam const& EntryPoint::getSpecializationParam(Index index) +{ + auto genericParamCount = m_genericSpecializationParams.getCount(); + if (index < genericParamCount) + { + return m_genericSpecializationParams[index]; + } + else + { + return m_existentialSpecializationParams[index - genericParamCount]; + } +} + +Index EntryPoint::getRequirementCount() +{ + // The only requirement of an entry point is the module that contains it. + // + // TODO: We will eventually want to support the case of an entry + // point nested in a `struct` type, in which case there should be + // a single requirement representing that outer type (so that multiple + // entry points nested under the same type can share the storage + // for parameters at that scope). + + // Note: the defensive coding is here because the + // "dummy" entry points we create for pass-through + // compilation will not have an associated module. + // + if (const auto module = getModule()) + { + return 1; + } + return 0; +} + +RefPtr<ComponentType> EntryPoint::getRequirement(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + SLANG_ASSERT(getModule()); + return getModule(); +} + +String EntryPoint::getEntryPointMangledName(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + + return m_mangledName; +} + +String EntryPoint::getEntryPointNameOverride(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + + return m_name ? m_name->text : ""; +} + +void EntryPoint::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo)); +} + +void EntryPoint::buildHash(DigestBuilder<SHA1>& builder) +{ + SLANG_UNUSED(builder); +} + +List<Module*> const& EntryPoint::getModuleDependencies() +{ + if (auto module = getModule()) + return module->getModuleDependencies(); + + static List<Module*> empty; + return empty; +} + +List<SourceFile*> const& EntryPoint::getFileDependencies() +{ + if (const auto module = getModule()) + return getModule()->getFileDependencies(); + + static List<SourceFile*> empty; + return empty; +} + +} // namespace Slang diff --git a/source/slang/slang-entry-point.h b/source/slang/slang-entry-point.h new file mode 100644 index 000000000..16499a542 --- /dev/null +++ b/source/slang/slang-entry-point.h @@ -0,0 +1,333 @@ +// slang-entry-point.h +#pragma once + +// +// This file provides the `EntryPoint` type, which is used +// to represent an entry point that has been identified +// and validated by the compiler front-end inside of some +// `Module`. +// +// Note that the `FrontEndEntryPointRequest` type, corresponding +// to use of the `-entry` command-line option, is not +// declared here, and instead comes from `slang-compile-request.h`. +// + +#include "slang-linkable.h" + +namespace Slang +{ + +/// Describes an entry point for the purposes of layout and code generation. +/// +/// This type intentionally does not distinguish between entry +/// points that were discovered because of a `[shader(...)]` attribute +/// and those that were identified by a `FrontEndEntryPointRequest` +/// (e.g., as a result of a `-entry` command-line option). The +/// intention is that *how* an `EntryPoint` came to be is a purely +/// front-end consideration, and the back-end of the compiler should +/// never depend on that information. +/// +/// This class also tracks any generic arguments to the entry point, +/// in the case that it is a specialization of a generic entry point. +/// +/// There is also a provision for creating a "dummy" entry point for +/// the purposes of pass-through compilation modes. Only the +/// `getName()` and `getProfile()` methods should be expected to +/// return useful data on pass-through entry points. +/// +class EntryPoint : public ComponentType, public slang::IEntryPoint +{ + typedef ComponentType Super; + +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Guid& guid); + + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCompileResult( + entryPointIndex, + targetIndex, + outCompileResult, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } + + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; + + /// Create an entry point that refers to the given function. + static RefPtr<EntryPoint> create( + Linkage* linkage, + DeclRef<FuncDecl> funcDeclRef, + Profile profile); + + /// Get the function decl-ref, including any generic arguments. + DeclRef<FuncDecl> getFuncDeclRef() { return m_funcDeclRef; } + + /// Get the function declaration (without generic arguments). + FuncDecl* getFuncDecl() { return m_funcDeclRef.getDecl(); } + + /// Get the name of the entry point + Name* getName() { return m_name; } + + /// Get the profile associated with the entry point + /// + /// Note: only the stage part of the profile is expected + /// to contain useful data, but certain legacy code paths + /// allow for "shader model" information to come via this path. + /// + Profile getProfile() { return m_profile; } + + /// Get the stage that the entry point is for. + Stage getStage() { return m_profile.getStage(); } + + /// Get the module that contains the entry point. + Module* getModule(); + + /// Get a list of modules that this entry point depends on. + /// + /// This will include the module that defines the entry point (see `getModule()`), + /// but may also include modules that are required by its generic type arguments. + /// + List<Module*> const& getModuleDependencies() + SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } + List<SourceFile*> const& getFileDependencies() + SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); } + + /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. + static RefPtr<EntryPoint> createDummyForPassThrough( + Linkage* linkage, + Name* name, + Profile profile); + + /// Create a dummy `EntryPoint` that stands in for a serialized entry point + static RefPtr<EntryPoint> createDummyForDeserialize( + Linkage* linkage, + Name* name, + Profile profile, + String mangledName); + + /// Get the number of existential type parameters for the entry point. + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; + + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + SpecializationParams const& getExistentialSpecializationParams() + { + return m_existentialSpecializationParams; + } + + Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } + Index getExistentialSpecializationParamCount() + { + return m_existentialSpecializationParams.getCount(); + } + + /// Get an array of all entry-point shader parameters. + List<ShaderParamInfo> const& getShaderParams() { return m_shaderParams; } + + Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return this; + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } + + class EntryPointSpecializationInfo : public SpecializationInfo + { + public: + DeclRef<FuncDecl> specializedFuncDeclRef; + List<ExpandedSpecializationArg> existentialSpecializationArgs; + }; + + SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE + { + return (slang::FunctionReflection*)m_funcDeclRef.declRefBase; + } + +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +private: + EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef<FuncDecl> funcDeclRef); + + void _collectGenericSpecializationParamsRec(Decl* decl); + void _collectShaderParams(); + + // The name of the entry point function (e.g., `main`) + // + Name* m_name = nullptr; + + // The declaration of the entry-point function itself. + // + DeclRef<FuncDecl> m_funcDeclRef; + + /// The mangled name of the entry point function + String m_mangledName; + + SpecializationParams m_genericSpecializationParams; + SpecializationParams m_existentialSpecializationParams; + + /// Information about entry-point parameters + List<ShaderParamInfo> m_shaderParams; + + // The profile that the entry point will be compiled for + // (this is a combination of the target stage, and also + // a feature level that sets capabilities) + // + // Note: the profile-version part of this should probably + // be moving towards deprecation, in favor of the version + // information (e.g., "Shader Model 5.1") always coming + // from the target, while the stage part is all that is + // intrinsic to the entry point. + // + Profile m_profile; +}; + +} // namespace Slang diff --git a/source/slang/slang-global-session.cpp b/source/slang/slang-global-session.cpp new file mode 100644 index 000000000..7b922dcf0 --- /dev/null +++ b/source/slang/slang-global-session.cpp @@ -0,0 +1,1217 @@ +// slang-global-session.cpp +#include "slang-global-session.h" + +#include "compiler-core/slang-artifact-desc-util.h" +#include "core/slang-archive-file-system.h" +#include "core/slang-performance-profiler.h" +#include "core/slang-type-convert-util.h" +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-doc-ast.h" +#include "slang-doc-markdown-writer.h" +#include "slang-options.h" +#include "slang-parser.h" +#include "slang-serialize-ast.h" +#include "slang-serialize-container.h" +#include "slang-serialize-ir.h" + +extern Slang::String get_slang_cuda_prelude(); +extern Slang::String get_slang_cpp_prelude(); +extern Slang::String get_slang_hlsl_prelude(); + +namespace Slang +{ + +void Session::init() +{ + SLANG_ASSERT(BaseTypeInfo::check()); + +#if SLANG_ENABLE_IR_BREAK_ALLOC + // Read environment variable for IR debugging + StringBuilder irBreakEnv; + if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( + UnownedStringSlice("SLANG_DEBUG_IR_BREAK"), + irBreakEnv))) + { + String envValue = irBreakEnv.produceString(); + if (envValue.getLength()) + { + _slangIRAllocBreak = stringToInt(envValue); + _slangIRPrintStackAtBreak = true; + } + } +#endif + + _initCodeGenTransitionMap(); + + ::memset(m_downstreamCompilerLocators, 0, sizeof(m_downstreamCompilerLocators)); + DownstreamCompilerUtil::setDefaultLocators(m_downstreamCompilerLocators); + m_downstreamCompilerSet = new DownstreamCompilerSet; + + m_completionTokenName = getNamePool()->getName("#?"); + + m_sharedLibraryLoader = DefaultSharedLibraryLoader::getSingleton(); + + // Set up the command line options + initCommandOptions(m_commandOptions); + + // Set up shared AST builder + m_sharedASTBuilder = new SharedASTBuilder; + m_sharedASTBuilder->init(this); + + // And the global ASTBuilder + auto builtinAstBuilder = m_sharedASTBuilder->getInnerASTBuilder(); + globalAstBuilder = builtinAstBuilder; + + // Make sure our source manager is initialized + builtinSourceManager.initialize(nullptr, nullptr); + + // Built in linkage uses the built in builder + m_builtinLinkage = new Linkage(this, builtinAstBuilder, nullptr); + m_builtinLinkage->m_optionSet.set(CompilerOptionName::DebugInformation, DebugInfoLevel::None); + + // Because the `Session` retains the builtin `Linkage`, + // we need to make sure that the parent pointer inside + // `Linkage` doesn't create a retain cycle. + // + // This operation ensures that the parent pointer will + // just be a raw pointer, so that the builtin linkage + // doesn't keep the parent session alive. + // + m_builtinLinkage->_stopRetainingParentSession(); + + // Create scopes for various language builtins. + // + // TODO: load these on-demand to avoid parsing + // the core module code for languages the user won't use. + + baseLanguageScope = builtinAstBuilder->create<Scope>(); + + // Will stay in scope as long as ASTBuilder + baseModuleDecl = + populateBaseLanguageModule(m_builtinLinkage->getASTBuilder(), baseLanguageScope); + + coreLanguageScope = builtinAstBuilder->create<Scope>(); + coreLanguageScope->nextSibling = baseLanguageScope; + + hlslLanguageScope = builtinAstBuilder->create<Scope>(); + hlslLanguageScope->nextSibling = coreLanguageScope; + + slangLanguageScope = builtinAstBuilder->create<Scope>(); + slangLanguageScope->nextSibling = hlslLanguageScope; + + glslLanguageScope = builtinAstBuilder->create<Scope>(); + glslLanguageScope->nextSibling = slangLanguageScope; + + glslModuleName = getNameObj("glsl"); + + { + for (Index i = 0; i < Index(SourceLanguage::CountOf); ++i) + { + m_defaultDownstreamCompilers[i] = PassThroughMode::None; + } + m_defaultDownstreamCompilers[Index(SourceLanguage::C)] = PassThroughMode::GenericCCpp; + m_defaultDownstreamCompilers[Index(SourceLanguage::CPP)] = PassThroughMode::GenericCCpp; + m_defaultDownstreamCompilers[Index(SourceLanguage::CUDA)] = PassThroughMode::NVRTC; + } + + // Set up default prelude code for target languages that need a prelude + m_languagePreludes[Index(SourceLanguage::CUDA)] = get_slang_cuda_prelude(); + m_languagePreludes[Index(SourceLanguage::CPP)] = get_slang_cpp_prelude(); + m_languagePreludes[Index(SourceLanguage::HLSL)] = get_slang_hlsl_prelude(); + + if (!spirvCoreGrammarInfo) + spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::getEmbeddedVersion(); +} + +Session::~Session() +{ + // This is necessary because this ASTBuilder uses the SharedASTBuilder also owned by the + // session. If the SharedASTBuilder gets dtored before the globalASTBuilder it has a dangling + // pointer, which is referenced in the ASTBuilder dtor (likely) causing a crash. + // + // By destroying first we know it is destroyed, before the SharedASTBuilder. + globalAstBuilder.setNull(); + + // destroy modules next + coreModules = decltype(coreModules)(); +} + + +Module* Session::getBuiltinModule(slang::BuiltinModuleName name) +{ + auto info = getBuiltinModuleInfo(name); + auto builtinLinkage = getBuiltinLinkage(); + auto moduleNameObj = builtinLinkage->getNamePool()->getName(info.name); + RefPtr<Module> module; + if (builtinLinkage->mapNameToLoadedModules.tryGetValue(moduleNameObj, module)) + return module.get(); + return nullptr; +} + +void Session::_initCodeGenTransitionMap() +{ + // TODO(JS): Might want to do something about these in the future... + + // PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); + // SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); + + // Set up the default ways to do compilations between code gen targets + auto& map = m_codeGenTransitionMap; + + // TODO(JS): There currently isn't a 'downstream compiler' for direct spirv output. If we did + // it would presumably a transition from SlangIR to SPIRV. + + // For C and C++ we default to use the 'genericCCpp' compiler + { + const CodeGenTarget sources[] = {CodeGenTarget::CSource, CodeGenTarget::CPPSource}; + for (auto source : sources) + { + // We *don't* add a default for host callable, as we will determine what is suitable + // depending on what is available. We prefer LLVM if that's available. If it's not we + // can use generic C/C++ compiler + + map.addTransition( + source, + CodeGenTarget::ShaderSharedLibrary, + PassThroughMode::GenericCCpp); + map.addTransition( + source, + CodeGenTarget::HostSharedLibrary, + PassThroughMode::GenericCCpp); + map.addTransition(source, CodeGenTarget::HostExecutable, PassThroughMode::GenericCCpp); + map.addTransition(source, CodeGenTarget::ObjectCode, PassThroughMode::GenericCCpp); + } + } + + + // Add all the straightforward transitions + map.addTransition(CodeGenTarget::CUDASource, CodeGenTarget::PTX, PassThroughMode::NVRTC); + map.addTransition(CodeGenTarget::HLSL, CodeGenTarget::DXBytecode, PassThroughMode::Fxc); + map.addTransition(CodeGenTarget::HLSL, CodeGenTarget::DXIL, PassThroughMode::Dxc); + map.addTransition(CodeGenTarget::GLSL, CodeGenTarget::SPIRV, PassThroughMode::Glslang); + map.addTransition(CodeGenTarget::Metal, CodeGenTarget::MetalLib, PassThroughMode::MetalC); + map.addTransition(CodeGenTarget::WGSL, CodeGenTarget::WGSLSPIRV, PassThroughMode::Tint); + // To assembly + map.addTransition(CodeGenTarget::SPIRV, CodeGenTarget::SPIRVAssembly, PassThroughMode::Glslang); + // We use glslang to turn SPIR-V into SPIR-V assembly. + map.addTransition( + CodeGenTarget::WGSLSPIRV, + CodeGenTarget::WGSLSPIRVAssembly, + PassThroughMode::Glslang); + map.addTransition(CodeGenTarget::DXIL, CodeGenTarget::DXILAssembly, PassThroughMode::Dxc); + map.addTransition( + CodeGenTarget::DXBytecode, + CodeGenTarget::DXBytecodeAssembly, + PassThroughMode::Fxc); + map.addTransition( + CodeGenTarget::MetalLib, + CodeGenTarget::MetalLibAssembly, + PassThroughMode::MetalC); +} + +void Session::addBuiltins(char const* sourcePath, char const* source) +{ + auto sourceBlob = StringBlob::moveCreate(String(source)); + + // TODO(tfoley): Add ability to directly new builtins to the appropriate scope + Module* module = nullptr; + addBuiltinSource(coreLanguageScope, sourcePath, sourceBlob, module); + if (module) + coreModules.add(module); +} + +void Session::setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) +{ + // External API allows passing of nullptr to reset the loader + loader = loader ? loader : DefaultSharedLibraryLoader::getSingleton(); + + _setSharedLibraryLoader(loader); +} + +ISlangSharedLibraryLoader* Session::getSharedLibraryLoader() +{ + return (m_sharedLibraryLoader == DefaultSharedLibraryLoader::getSingleton()) + ? nullptr + : m_sharedLibraryLoader.get(); +} + +SlangResult Session::checkCompileTargetSupport(SlangCompileTarget inTarget) +{ + auto target = CodeGenTarget(inTarget); + + const PassThroughMode mode = getDownstreamCompilerRequiredForTarget(target); + return (mode != PassThroughMode::None) ? checkPassThroughSupport(SlangPassThrough(mode)) + : SLANG_OK; +} + +SlangResult Session::checkPassThroughSupport(SlangPassThrough inPassThrough) +{ + return checkExternalCompilerSupport(this, PassThroughMode(inPassThrough)); +} + +void Session::writeCoreModuleDoc(String config) +{ + ASTBuilder* astBuilder = getBuiltinLinkage()->getASTBuilder(); + SourceManager* sourceManager = getBuiltinSourceManager(); + + DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); + + List<String> docStrings; + + // For all the modules add their doc output to docStrings + for (Module* m : coreModules) + { + RefPtr<ASTMarkup> markup(new ASTMarkup); + ASTMarkupUtil::extract(m->getModuleDecl(), sourceManager, &sink, markup); + + DocMarkdownWriter writer(markup, astBuilder, &sink); + auto rootPage = writer.writeAll(config.getUnownedSlice()); + File::writeAllText("toc.html", writer.writeTOC()); + rootPage->writeToDisk(); + rootPage->writeSummary(toSlice("summary.txt")); + } + ComPtr<ISlangBlob> diagnosticBlob; + sink.getBlobIfNeeded(diagnosticBlob.writeRef()); + if (diagnosticBlob && diagnosticBlob->getBufferSize() != 0) + { + // Write the diagnostic blob to stdout. + fprintf(stderr, "%s", (const char*)diagnosticBlob->getBufferPointer()); + } +} + +const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name) +{ + const char* result = nullptr; + switch (name) + { + case slang::BuiltinModuleName::Core: + result = "core"; + break; + case slang::BuiltinModuleName::GLSL: + result = "glsl"; + break; + default: + SLANG_UNEXPECTED("Unknown builtin module"); + } + return result; +} + +TypeCheckingCache* Session::getTypeCheckingCache() +{ + return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); +} + +Session::BuiltinModuleInfo Session::getBuiltinModuleInfo(slang::BuiltinModuleName name) +{ + Session::BuiltinModuleInfo result; + + result.name = getBuiltinModuleNameStr(name); + + switch (name) + { + case slang::BuiltinModuleName::Core: + result.languageScope = coreLanguageScope; + break; + case slang::BuiltinModuleName::GLSL: + result.name = "glsl"; + result.languageScope = glslLanguageScope; + break; + default: + SLANG_UNEXPECTED("Unknown builtin module"); + } + return result; +} + +SlangResult Session::compileCoreModule(slang::CompileCoreModuleFlags compileFlags) +{ + return compileBuiltinModule(slang::BuiltinModuleName::Core, compileFlags); +} + +void Session::getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName) +{ + switch (moduleName) + { + case slang::BuiltinModuleName::Core: + sb << (const char*)getCoreLibraryCode()->getBufferPointer() + << (const char*)getHLSLLibraryCode()->getBufferPointer() + << (const char*)getAutodiffLibraryCode()->getBufferPointer(); + break; + case slang::BuiltinModuleName::GLSL: + sb << (const char*)getGLSLLibraryCode()->getBufferPointer(); + break; + } +} + +SlangResult Session::compileBuiltinModule( + slang::BuiltinModuleName moduleName, + slang::CompileCoreModuleFlags compileFlags) +{ + SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); + +#ifdef _DEBUG + time_t beginTime = 0; + if (moduleName == slang::BuiltinModuleName::Core) + { + // Print a message in debug builds to notice the user that compiling the core module + // can take a while. + time(&beginTime); + fprintf(stderr, "Compiling core module on debug build, this can take a while.\n"); + } +#endif + BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(moduleName); + auto moduleNameObj = m_builtinLinkage->getNamePool()->getName(builtinModuleInfo.name); + if (m_builtinLinkage->mapNameToLoadedModules.tryGetValue(moduleNameObj)) + { + // Already have the builtin module loaded + return SLANG_FAIL; + } + + StringBuilder moduleSrcBuilder; + getBuiltinModuleSource(moduleSrcBuilder, moduleName); + + // TODO(JS): Could make this return a SlangResult as opposed to exception + auto moduleSrcBlob = StringBlob::moveCreate(moduleSrcBuilder.produceString()); + Module* compiledModule = nullptr; + addBuiltinSource( + builtinModuleInfo.languageScope, + builtinModuleInfo.name, + moduleSrcBlob, + compiledModule); + + if (moduleName == slang::BuiltinModuleName::Core) + { + // We need to retain this AST so that we can use it in other code + // (Note that the `Scope` type does not retain the AST it points to) + coreModules.add(compiledModule); + } + + if (compileFlags & slang::CompileCoreModuleFlag::WriteDocumentation) + { + // Load config file first. + String configText; + if (SLANG_FAILED(File::readAllText("config.txt", configText))) + { + fprintf( + stderr, + "Error writing documentation: config file not found on current working " + "directory.\n"); + } + else + { + writeCoreModuleDoc(configText); + } + } + + finalizeSharedASTBuilder(); + +#ifdef _DEBUG + if (moduleName == slang::BuiltinModuleName::Core) + { + time_t endTime; + time(&endTime); + fprintf(stderr, "Compiling core module took %.2f seconds.\n", difftime(endTime, beginTime)); + } +#endif + return SLANG_OK; +} + +SlangResult Session::loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) +{ + return loadBuiltinModule(slang::BuiltinModuleName::Core, coreModule, coreModuleSizeInBytes); +} + +SlangResult Session::loadBuiltinModule( + slang::BuiltinModuleName moduleName, + const void* moduleData, + size_t sizeInBytes) +{ + SLANG_PROFILE; + + + SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); + + BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(moduleName); + auto nameObj = m_builtinLinkage->getNamePool()->getName(builtinModuleInfo.name); + if (m_builtinLinkage->mapNameToLoadedModules.containsKey(nameObj)) + { + // Already have a core module loaded + return SLANG_FAIL; + } + + // Make a file system to read it from + ComPtr<ISlangFileSystemExt> fileSystem; + SLANG_RETURN_ON_FAIL(loadArchiveFileSystem(moduleData, sizeInBytes, fileSystem)); + + // Let's try loading serialized modules and adding them + Module* module = nullptr; + SLANG_RETURN_ON_FAIL(_readBuiltinModule( + fileSystem, + builtinModuleInfo.languageScope, + builtinModuleInfo.name, + module)); + + if (moduleName == slang::BuiltinModuleName::Core) + { + // We need to retain this AST so that we can use it in other code + // (Note that the `Scope` type does not retain the AST it points to) + coreModules.add(module); + } + + finalizeSharedASTBuilder(); + return SLANG_OK; +} + +SlangResult Session::saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) +{ + return saveBuiltinModule(slang::BuiltinModuleName::Core, archiveType, outBlob); +} + +SlangResult Session::saveBuiltinModule( + slang::BuiltinModuleName moduleTag, + SlangArchiveType archiveType, + ISlangBlob** outBlob) +{ + // If no builtin modules have been loaded, then there is + // nothing to save, and we fail immediately. + // + if (m_builtinLinkage->mapNameToLoadedModules.getCount() == 0) + { + return SLANG_FAIL; + } + + // The module will need to be looked up by its name, and + // will also be serialized out to a path with a matching name. + // + BuiltinModuleInfo moduleInfo = getBuiltinModuleInfo(moduleTag); + const char* moduleName = moduleInfo.name; + + // If we cannot find a loaded module in the linkage with + // the appropriate name, then for some reason it hasn't + // been loaded, and we fail. + // + RefPtr<Module> module; + m_builtinLinkage->mapNameToLoadedModules.tryGetValue( + getNameObj(UnownedStringSlice(moduleName)), + module); + if (!module) + { + return SLANG_FAIL; + } + + // AST serialization needs access to an AST builder, so + // we establish a current builder for the duration of + // the serialization process. + // + SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); + + // The serialized module will be represented as a logical + // file in an archive, so we create a logical file system + // to represent that archive. + // + ComPtr<ISlangMutableFileSystem> fileSystem; + SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); + // + // The created file system must support the `IArchiveFileSystem` + // interface (since we created it with `createArchiveFileSystem`). + // + auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); + if (!archiveFileSystem) + { + return SLANG_FAIL; + } + + // The output file name that we'll write to in that file system + // is just the builtin module name with a `.slang-module` suffix. + // + StringBuilder moduleFileName; + moduleFileName << moduleName << ".slang-module"; + + // The module serialization step has some options that we need + // to configure appropriately. + // + SerialContainerUtil::WriteOptions options; + // + // We want builtin modules to be saved with their source location + // information. + // + // And in order to work with source locations, the serialization + // process will also need access to the source manager that + // can translate locations into their humane format. + // + options.sourceManagerToUseWhenSerializingSourceLocs = m_builtinLinkage->getSourceManager(); + + // At this point we can finally delegate down to the next level, + // which handles the serialization of a Slang module into a + // byte stream. + // + OwnedMemoryStream stream(FileAccess::Write); + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(module, options, &stream)); + auto contents = stream.getContents(); + + // Once the stream that represents the module has been written, we can + // write it to a file in the logical file system. + // + // TODO(tfoley): why can't the file system let us open the file for output? + // + SLANG_RETURN_ON_FAIL(fileSystem->saveFile( + moduleFileName.getBuffer(), + contents.getBuffer(), + contents.getCount())); + + // And finally, we can ask the archive file system to serialize itself + // out as a blob of bytes, which yields the final serialized representation + // of the module. + // + SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive( + // The `true` here indicates that the blob that gets created should own + // its content, independent from the file system object itself; otherwise + // the file system might return a blob that shares storage with itself. + true, + outBlob)); + + return SLANG_OK; +} + +SlangResult Session::_readBuiltinModule( + ISlangFileSystem* fileSystem, + Scope* scope, + String moduleName, + Module*& outModule) +{ + // Get the name of the module + StringBuilder moduleFilename; + moduleFilename << moduleName << ".slang-module"; + + // Load it + ComPtr<ISlangBlob> fileContents; + SLANG_RETURN_ON_FAIL(fileSystem->loadFile(moduleFilename.getBuffer(), fileContents.writeRef())); + + RIFF::RootChunk const* rootChunk = RIFF::RootChunk::getFromBlob(fileContents); + if (!rootChunk) + { + return SLANG_FAIL; + } + + Linkage* linkage = getBuiltinLinkage(); + SourceManager* sourceManager = getBuiltinSourceManager(); + NamePool* sessionNamePool = &namePool; + + auto moduleChunk = ModuleChunk::find(rootChunk); + if (!moduleChunk) + return SLANG_FAIL; + + SHA1::Digest moduleDigest = moduleChunk->getDigest(); + + auto irChunk = moduleChunk->findIR(); + if (!irChunk) + return SLANG_FAIL; + + auto astChunk = moduleChunk->findAST(); + if (!astChunk) + return SLANG_FAIL; + + // Source location information is stored as a distinct + // chunk from the IR and AST, so we need to search for + // that chunk and then set up the information for use + // in the IR and AST deserialization (if we find anything). + // + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = DebugChunk::find(moduleChunk)) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } + + // At this point we create the `Module` object that will + // represent the builtin module we are reading, although + // it is still possible that deserialization will fail + // at one of the following steps. + // + auto astBuilder = linkage->getASTBuilder(); + RefPtr<Module> module(new Module(linkage, astBuilder)); + module->setName(moduleName); + module->setDigest(moduleDigest); + + + // Next, we set about deserializing the AST representation + // of the module. + // + auto moduleDecl = readSerializedModuleAST( + linkage, + astBuilder, + nullptr, // no sink + fileContents, + astChunk, + sourceLocReader, + SourceLoc()); + if (!moduleDecl) + { + return SLANG_FAIL; + } + moduleDecl->module = module; + module->setModuleDecl(moduleDecl); + + // After the AST module has been read in, we next look + // to deserialize the IR module. + // + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(readSerializedModuleIR(irChunk, this, sourceLocReader, irModule)); + + irModule->setName(module->getNameObj()); + module->setIRModule(irModule); + + // Put in the loaded module map + linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); + + + // Add the resulting code to the appropriate scope + if (!scope->containerDecl) + { + // We are the first chunk of code to be loaded for this scope + scope->containerDecl = moduleDecl; + } + else + { + // We need to create a new scope to link into the whole thing + auto subScope = linkage->getASTBuilder()->create<Scope>(); + subScope->containerDecl = moduleDecl; + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + } + + outModule = module.get(); + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Session::queryInterface(SlangUUID const& uuid, void** outObject) +{ + if (uuid == Session::getTypeGuid()) + { + addReference(); + *outObject = static_cast<Session*>(this); + return SLANG_OK; + } + + if (uuid == ISlangUnknown::getTypeGuid() || uuid == IGlobalSession::getTypeGuid()) + { + addReference(); + *outObject = static_cast<slang::IGlobalSession*>(this); + return SLANG_OK; + } + + return SLANG_E_NO_INTERFACE; +} + +static size_t _getStructureSize(const uint8_t* src) +{ + size_t size = 0; + ::memcpy(&size, src, sizeof(size_t)); + return size; +} + +template<typename T> +static T makeFromSizeVersioned(const uint8_t* src) +{ + // The structure size must be size_t + SLANG_COMPILE_TIME_ASSERT(sizeof(((T*)src)->structureSize) == sizeof(size_t)); + + // The structureSize field *must* be the first element of T + // Ideally would use SLANG_COMPILE_TIME_ASSERT, but that doesn't work on gcc. + // Can't just assert, because determined to be a constant expression + { + auto offset = SLANG_OFFSET_OF(T, structureSize); + SLANG_ASSERT(offset == 0); + // Needed because offset is only 'used' by an assert + SLANG_UNUSED(offset); + } + + // The source size is held in the first element of T, and will be in the first bytes of src. + const size_t srcSize = _getStructureSize(src); + const size_t dstSize = sizeof(T); + + // If they are the same size, and appropriate alignment we can just cast and return + if (srcSize == dstSize && (size_t(src) & (alignof(T) - 1)) == 0) + { + return *(const T*)src; + } + + // Assumes T can default constructed sensibly + T dst; + + // It's structure size should be setup and should be dstSize + SLANG_ASSERT(dst.structureSize == dstSize); + + // The size to copy is the minimum on the two sizes + const auto copySize = std::min(srcSize, dstSize); + ::memcpy(&dst, src, copySize); + + // The final struct size is the destination size + dst.structureSize = dstSize; + + return dst; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Session::createSession(slang::SessionDesc const& inDesc, slang::ISession** outSession) +{ + RefPtr<ASTBuilder> astBuilder(new ASTBuilder(m_sharedASTBuilder, "Session::astBuilder")); + slang::SessionDesc desc = makeFromSizeVersioned<slang::SessionDesc>((uint8_t*)&inDesc); + + RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage()); + + if (desc.skipSPIRVValidation) + { + linkage->m_optionSet.set(CompilerOptionName::SkipSPIRVValidation, true); + } + + { + std::lock_guard<std::mutex> lock(m_typeCheckingCacheMutex); + if (m_typeCheckingCache) + linkage->m_typeCheckingCache = + new TypeCheckingCache(*static_cast<TypeCheckingCache*>(m_typeCheckingCache.get())); + } + + Int searchPathCount = desc.searchPathCount; + for (Int ii = 0; ii < searchPathCount; ++ii) + { + linkage->addSearchPath(desc.searchPaths[ii]); + } + + Int macroCount = desc.preprocessorMacroCount; + for (Int ii = 0; ii < macroCount; ++ii) + { + auto& macro = desc.preprocessorMacros[ii]; + linkage->addPreprocessorDefine(macro.name, macro.value); + } + + if (desc.fileSystem) + { + linkage->setFileSystem(desc.fileSystem); + } + + if (desc.structureSize >= offsetof(slang::SessionDesc, enableEffectAnnotations)) + { + linkage->m_optionSet.set( + CompilerOptionName::EnableEffectAnnotations, + desc.enableEffectAnnotations); + } + + linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); + + if (!linkage->m_optionSet.hasOption(CompilerOptionName::MatrixLayoutColumn) && + !linkage->m_optionSet.hasOption(CompilerOptionName::MatrixLayoutRow)) + linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode); + + { + const Int targetCount = desc.targetCount; + const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets); + for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr)) + { + const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr); + linkage->addTarget(targetDesc); + } + } + + // If any target requires debug info, then we will need to enable debug info when lowering to + // target-agnostic IR. The target-agnostic IR will only include debug info if the linkage IR + // options specify that it should, so make sure the linkage debug info level is greater than or + // equal to that of any target. + DebugInfoLevel linkageDebugInfoLevel = linkage->m_optionSet.getDebugInfoLevel(); + for (auto target : linkage->targets) + linkageDebugInfoLevel = + Math::Max(linkageDebugInfoLevel, target->getOptionSet().getDebugInfoLevel()); + linkage->m_optionSet.set(CompilerOptionName::DebugInformation, linkageDebugInfoLevel); + + // Add any referenced modules to the linkage + for (auto& option : linkage->m_optionSet.options) + { + if (option.key != CompilerOptionName::ReferenceModule) + continue; + for (auto& path : option.value) + { + DiagnosticSink sink; + ComPtr<IArtifact> artifact; + SlangResult result = createArtifactFromReferencedModule( + path.stringValue, + SourceLoc{}, + &sink, + artifact.writeRef()); + if (SLANG_FAILED(result)) + { + sink.diagnose(SourceLoc{}, Diagnostics::unableToReadFile, path.stringValue); + return result; + } + linkage->m_libModules.add(artifact); + } + } + + *outSession = asExternal(linkage.detach()); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Session::createCompileRequest(slang::ICompileRequest** outCompileRequest) +{ + auto req = new EndToEndCompileRequest(this); + + // Give it a ref (for output) + req->addRef(); + // Check it is what we think it should be + SLANG_ASSERT(req->debugGetReferenceCount() == 1); + + *outCompileRequest = req; + return SLANG_OK; +} + +SLANG_NO_THROW SlangProfileID SLANG_MCALL Session::findProfile(char const* name) +{ + return SlangProfileID(Slang::Profile::lookUp(name).raw); +} + +SLANG_NO_THROW SlangCapabilityID SLANG_MCALL Session::findCapability(char const* name) +{ + return SlangCapabilityID(Slang::findCapabilityName(UnownedTerminatedStringSlice(name))); +} + +SLANG_NO_THROW void SLANG_MCALL +Session::setDownstreamCompilerPath(SlangPassThrough inPassThrough, char const* path) +{ + PassThroughMode passThrough = PassThroughMode(inPassThrough); + SLANG_ASSERT( + int(passThrough) > int(PassThroughMode::None) && + int(passThrough) < int(PassThroughMode::CountOf)); + + if (m_downstreamCompilerPaths[int(passThrough)] != path) + { + // Make access redetermine compiler + resetDownstreamCompiler(passThrough); + // Set the path + m_downstreamCompilerPaths[int(passThrough)] = path; + } +} + +SLANG_NO_THROW void SLANG_MCALL +Session::setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) +{ + PassThroughMode downstreamCompiler = PassThroughMode(inPassThrough); + SLANG_ASSERT( + int(downstreamCompiler) > int(PassThroughMode::None) && + int(downstreamCompiler) < int(PassThroughMode::CountOf)); + const SourceLanguage sourceLanguage = + getDefaultSourceLanguageForDownstreamCompiler(downstreamCompiler); + setLanguagePrelude(SlangSourceLanguage(sourceLanguage), prelude); +} + +SLANG_NO_THROW void SLANG_MCALL +Session::getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) +{ + PassThroughMode downstreamCompiler = PassThroughMode(inPassThrough); + SLANG_ASSERT( + int(downstreamCompiler) > int(PassThroughMode::None) && + int(downstreamCompiler) < int(PassThroughMode::CountOf)); + const SourceLanguage sourceLanguage = + getDefaultSourceLanguageForDownstreamCompiler(downstreamCompiler); + getLanguagePrelude(SlangSourceLanguage(sourceLanguage), outPrelude); +} + +SLANG_NO_THROW void SLANG_MCALL +Session::setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) +{ + SourceLanguage sourceLanguage = SourceLanguage(inSourceLanguage); + SLANG_ASSERT( + int(sourceLanguage) > int(SourceLanguage::Unknown) && + int(sourceLanguage) < int(SourceLanguage::CountOf)); + + SLANG_ASSERT(sourceLanguage != SourceLanguage::Unknown); + + if (sourceLanguage != SourceLanguage::Unknown) + { + m_languagePreludes[int(sourceLanguage)] = prelude; + } +} + +SLANG_NO_THROW void SLANG_MCALL +Session::getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) +{ + SourceLanguage sourceLanguage = SourceLanguage(inSourceLanguage); + + *outPrelude = nullptr; + if (sourceLanguage != SourceLanguage::Unknown) + { + SLANG_ASSERT( + int(sourceLanguage) > int(SourceLanguage::Unknown) && + int(sourceLanguage) < int(SourceLanguage::CountOf)); + *outPrelude = + Slang::StringUtil::createStringBlob(m_languagePreludes[int(sourceLanguage)]).detach(); + } +} + +SLANG_NO_THROW const char* SLANG_MCALL Session::getBuildTagString() +{ + return ::Slang::getBuildTagString(); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Session::setDefaultDownstreamCompiler( + SlangSourceLanguage sourceLanguage, + SlangPassThrough defaultCompiler) +{ + if (DownstreamCompilerInfo::canCompile(defaultCompiler, sourceLanguage)) + { + m_defaultDownstreamCompilers[int(sourceLanguage)] = PassThroughMode(defaultCompiler); + return SLANG_OK; + } + return SLANG_FAIL; +} + +SlangPassThrough SLANG_MCALL +Session::getDefaultDownstreamCompiler(SlangSourceLanguage inSourceLanguage) +{ + SLANG_ASSERT(inSourceLanguage >= 0 && inSourceLanguage < SLANG_SOURCE_LANGUAGE_COUNT_OF); + auto sourceLanguage = SourceLanguage(inSourceLanguage); + return SlangPassThrough(m_defaultDownstreamCompilers[int(sourceLanguage)]); +} + +void Session::setDownstreamCompilerForTransition( + SlangCompileTarget source, + SlangCompileTarget target, + SlangPassThrough compiler) +{ + if (compiler == SLANG_PASS_THROUGH_NONE) + { + // Removing the transition means a default can be used + m_codeGenTransitionMap.removeTransition(CodeGenTarget(source), CodeGenTarget(target)); + } + else + { + m_codeGenTransitionMap.addTransition( + CodeGenTarget(source), + CodeGenTarget(target), + PassThroughMode(compiler)); + } +} + +SlangPassThrough Session::getDownstreamCompilerForTransition( + SlangCompileTarget inSource, + SlangCompileTarget inTarget) +{ + const CodeGenTarget source = CodeGenTarget(inSource); + const CodeGenTarget target = CodeGenTarget(inTarget); + + if (m_codeGenTransitionMap.hasTransition(source, target)) + { + return (SlangPassThrough)m_codeGenTransitionMap.getTransition(source, target); + } + + const auto desc = ArtifactDescUtil::makeDescForCompileTarget(inTarget); + + // Special case host-callable + if ((desc.kind == ArtifactKind::HostCallable) && + (source == CodeGenTarget::CSource || source == CodeGenTarget::CPPSource)) + { + // We prefer LLVM if it's available + if (const auto llvm = getOrLoadDownstreamCompiler(PassThroughMode::LLVM, nullptr)) + { + return SLANG_PASS_THROUGH_LLVM; + } + } + + // Use the legacy 'sourceLanguage' default mechanism. + // This says nothing about the target type, so it is *assumed* the target type is possible + // If not it will fail when trying to compile to an unknown target + const SourceLanguage sourceLanguage = + (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget(inSource); + if (sourceLanguage != SourceLanguage::Unknown) + { + return getDefaultDownstreamCompiler(SlangSourceLanguage(sourceLanguage)); + } + + // Unknwon + return SLANG_PASS_THROUGH_NONE; +} + +IDownstreamCompiler* Session::getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target) +{ + PassThroughMode compilerType = (PassThroughMode)getDownstreamCompilerForTransition( + SlangCompileTarget(source), + SlangCompileTarget(target)); + return getOrLoadDownstreamCompiler(compilerType, nullptr); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Session::setSPIRVCoreGrammar(char const* jsonPath) +{ + if (!jsonPath) + { + spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::getEmbeddedVersion(); + SLANG_ASSERT(spirvCoreGrammarInfo); + } + else + { + SourceManager* sourceManager = getBuiltinSourceManager(); + SLANG_ASSERT(sourceManager); + DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); + + String contents; + const auto readRes = File::readAllText(jsonPath, contents); + if (SLANG_FAILED(readRes)) + { + sink.diagnose(SourceLoc{}, Diagnostics::unableToReadFile, jsonPath); + return readRes; + } + const auto pathInfo = PathInfo::makeFromString(jsonPath); + const auto sourceFile = sourceManager->createSourceFileWithString(pathInfo, contents); + const auto sourceView = sourceManager->createSourceView(sourceFile, nullptr, SourceLoc()); + spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::loadFromJSON(*sourceView, sink); + } + return spirvCoreGrammarInfo ? SLANG_OK : SLANG_FAIL; +} + +struct ParsedCommandLineData : public ISlangUnknown, public ComObject +{ + SLANG_COM_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Slang::Guid& guid) + { + if (guid == ISlangUnknown::getTypeGuid()) + return this; + return nullptr; + } + List<SerializedOptionsData> options; + List<slang::TargetDesc> targets; +}; + +SLANG_NO_THROW SlangResult SLANG_MCALL Session::parseCommandLineArguments( + int argc, + const char* const* argv, + slang::SessionDesc* outDesc, + ISlangUnknown** outAllocation) +{ + if (outDesc->structureSize < sizeof(slang::SessionDesc)) + return SLANG_E_BUFFER_TOO_SMALL; + RefPtr<ParsedCommandLineData> outData = new ParsedCommandLineData(); + RefPtr<EndToEndCompileRequest> tempReq = new EndToEndCompileRequest(this); + tempReq->processCommandLineArguments(argv, argc); + outData->options.setCount(1 + tempReq->getLinkage()->targets.getCount()); + int optionDataIndex = 0; + SerializedOptionsData& optionData = outData->options[optionDataIndex]; + optionDataIndex++; + tempReq->getOptionSet().serialize(&optionData); + tempReq->m_optionSetForDefaultTarget.serialize(&optionData); + for (auto target : tempReq->getLinkage()->targets) + { + slang::TargetDesc tdesc; + SerializedOptionsData& targetOptionData = outData->options[optionDataIndex]; + optionDataIndex++; + tempReq->getTargetOptionSet(target).serialize(&targetOptionData); + tdesc.compilerOptionEntryCount = (uint32_t)targetOptionData.entries.getCount(); + tdesc.compilerOptionEntries = targetOptionData.entries.getBuffer(); + outData->targets.add(tdesc); + } + outDesc->compilerOptionEntryCount = (uint32_t)optionData.entries.getCount(); + outDesc->compilerOptionEntries = optionData.entries.getBuffer(); + outDesc->targetCount = outData->targets.getCount(); + outDesc->targets = outData->targets.getBuffer(); + *outAllocation = outData.get(); + outData->addRef(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Session::getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) +{ + ComPtr<slang::ISession> tempSession; + createSession(*sessionDesc, tempSession.writeRef()); + auto linkage = static_cast<Linkage*>(tempSession.get()); + DigestBuilder<SHA1> digestBuilder; + linkage->buildHash(digestBuilder, -1); + auto blob = digestBuilder.finalize().toBlob(); + *outBlob = blob.detach(); + return SLANG_OK; +} + +void Session::addBuiltinSource( + Scope* scope, + String const& path, + ISlangBlob* sourceBlob, + Module*& outModule) +{ + SourceManager* sourceManager = getBuiltinSourceManager(); + + DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); + + RefPtr<FrontEndCompileRequest> compileRequest = + new FrontEndCompileRequest(m_builtinLinkage, nullptr, &sink); + compileRequest->m_isCoreModuleCode = true; + + // Set the source manager on the sink + sink.setSourceManager(sourceManager); + // Make the linkage use the builtin source manager + Linkage* linkage = compileRequest->getLinkage(); + linkage->setSourceManager(sourceManager); + + Name* moduleName = getNamePool()->getName(path); + auto translationUnitIndex = + compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); + + compileRequest->addTranslationUnitSourceBlob(translationUnitIndex, path, sourceBlob); + + SlangResult res = compileRequest->executeActionsInner(); + if (SLANG_FAILED(res)) + { + char const* diagnostics = sink.outputBuffer.getBuffer(); + fprintf(stderr, "%s", diagnostics); + + PlatformUtil::outputDebugMessage(diagnostics); + + SLANG_UNEXPECTED("error in Slang core module"); + } + + // Compiling the core module should not yield any warnings. + SLANG_ASSERT(sink.outputBuffer.getLength() == 0); + + // Extract the AST for the code we just parsed + auto module = compileRequest->translationUnits[translationUnitIndex]->getModule(); + auto moduleDecl = module->getModuleDecl(); + + // Extact documentation markup. + ASTMarkup markup; + ASTMarkupUtil::extract(moduleDecl, sourceManager, &sink, &markup); + markup.attachToAST(); + + // Put in the loaded module map + linkage->mapNameToLoadedModules.add(moduleName, module); + + // Add the resulting code to the appropriate scope + if (!scope->containerDecl) + { + // We are the first chunk of code to be loaded for this scope + scope->containerDecl = moduleDecl; + } + else + { + // We need to create a new scope to link into the whole thing + auto subScope = module->getASTBuilder()->create<Scope>(); + subScope->containerDecl = moduleDecl; + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + } + + outModule = module; +} + +SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) +{ + // Check if the type is supported on this compile + if (passThrough == PassThroughMode::None) + { + // If no pass through -> that will always work! + return SLANG_OK; + } + + return session->getOrLoadDownstreamCompiler(passThrough, nullptr) ? SLANG_OK + : SLANG_E_NOT_FOUND; +} + +} // namespace Slang diff --git a/source/slang/slang-global-session.h b/source/slang/slang-global-session.h new file mode 100644 index 000000000..56984bdbd --- /dev/null +++ b/source/slang/slang-global-session.h @@ -0,0 +1,366 @@ +// slang-global-session.h +#pragma once + +// +// This file provides the `Session` type, and the implementation +// of the `slang::IGlobalSession` interface from the public API. +// +// TODO: there is an unfortunate and confusing situation +// where the public Slang API `ISession` type is implemented +// by the internal `Linkage` class, while the internal +// `Session` class implements the `IGlobalSession` interface +// from the public API. +// + +#include "../compiler-core/slang-downstream-compiler-set.h" +#include "../compiler-core/slang-downstream-compiler-util.h" +#include "../compiler-core/slang-downstream-compiler.h" +#include "../compiler-core/slang-spirv-core-grammar.h" +#include "../core/slang-command-options.h" +#include "slang-pass-through.h" +#include "slang-target.h" + +namespace Slang +{ + + +class CodeGenTransitionMap +{ +public: + struct Pair + { + typedef Pair ThisType; + SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const + { + return source == rhs.source && target == rhs.target; + } + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + SLANG_FORCE_INLINE HashCode getHashCode() const + { + return combineHash(HashCode(source), HashCode(target)); + } + + CodeGenTarget source; + CodeGenTarget target; + }; + + void removeTransition(CodeGenTarget source, CodeGenTarget target) + { + m_map.remove(Pair{source, target}); + } + void addTransition(CodeGenTarget source, CodeGenTarget target, PassThroughMode compiler) + { + SLANG_ASSERT(source != target); + m_map.set(Pair{source, target}, compiler); + } + bool hasTransition(CodeGenTarget source, CodeGenTarget target) const + { + return m_map.containsKey(Pair{source, target}); + } + PassThroughMode getTransition(CodeGenTarget source, CodeGenTarget target) const + { + const Pair pair{source, target}; + auto value = m_map.tryGetValue(pair); + return value ? *value : PassThroughMode::None; + } + +protected: + Dictionary<Pair, PassThroughMode> m_map; +}; + +/// A global session for interaction with the Slang compiler. +/// +/// A `Session` provides a context for ongoing interaction +/// between a client application and the Slang API. Creating +/// a `Session` has an up-front cost that then makes creation +/// of one or more `Linkage`s cheaper. +/// +/// The main services provided by a `Session` are: +/// +/// * This class implements the `slang::IGobalSession` interface +/// from the public Slang API. +/// +/// * The global session provides a scope for a loading built-in +/// modules (including the core module), such that the costs +/// associated with those modules can be amortized across +/// multiple `Linkage`s using the same global session. +/// +/// * The global session provides a scope for interactions with +/// the surrounding OS, and application-specific customizations +/// related to it. This includes locating other tools such as +/// downstream compilers, which may be implemented either as +/// executables or shared libraries. +/// +/// * The global session provides a scope for application-specific +/// injection of custom source code to be treated like a builtin +/// module, or text to be used as the "prelude" for particular +/// downstream languages/compilers. This functionality should be +/// considered as a legacy and/or deprecated feature. +/// +/// * The global session provides various one-off services that +/// are specific to the `slangc` tool or to the build process +/// for Slang itself. Some of these have been exposed through the +/// `slang::IGlobalSession` interface, but should be *not* be +/// used by user applications. +/// +class Session : public RefObject, public slang::IGlobalSession +{ +public: + SLANG_COM_INTERFACE( + 0xd6b767eb, + 0xd786, + 0x4343, + {0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a}) + + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) + SLANG_OVERRIDE; + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE + + // slang::IGlobalSession + SLANG_NO_THROW SlangResult SLANG_MCALL + createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; + SLANG_NO_THROW SlangProfileID SLANG_MCALL findProfile(char const* name) override; + SLANG_NO_THROW void SLANG_MCALL + setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path) override; + SLANG_NO_THROW void SLANG_MCALL + setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) override; + SLANG_NO_THROW void SLANG_MCALL + getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) override; + SLANG_NO_THROW const char* SLANG_MCALL getBuildTagString() override; + SLANG_NO_THROW SlangResult SLANG_MCALL setDefaultDownstreamCompiler( + SlangSourceLanguage sourceLanguage, + SlangPassThrough defaultCompiler) override; + SLANG_NO_THROW SlangPassThrough SLANG_MCALL + getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage) override; + + SLANG_NO_THROW void SLANG_MCALL + setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) override; + SLANG_NO_THROW void SLANG_MCALL + getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) override; + + SLANG_NO_THROW SlangResult SLANG_MCALL + createCompileRequest(slang::ICompileRequest** outCompileRequest) override; + + SLANG_NO_THROW void SLANG_MCALL + addBuiltins(char const* sourcePath, char const* sourceString) override; + SLANG_NO_THROW void SLANG_MCALL + setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) override; + SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL getSharedLibraryLoader() override; + SLANG_NO_THROW SlangResult SLANG_MCALL + checkCompileTargetSupport(SlangCompileTarget target) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + checkPassThroughSupport(SlangPassThrough passThrough) override; + + void writeCoreModuleDoc(String config); + SLANG_NO_THROW SlangResult SLANG_MCALL + compileCoreModule(slang::CompileCoreModuleFlags flags) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) override; + + SLANG_NO_THROW SlangResult SLANG_MCALL compileBuiltinModule( + slang::BuiltinModuleName moduleName, + slang::CompileCoreModuleFlags flags) override; + SLANG_NO_THROW SlangResult SLANG_MCALL loadBuiltinModule( + slang::BuiltinModuleName moduleName, + const void* coreModule, + size_t coreModuleSizeInBytes) override; + SLANG_NO_THROW SlangResult SLANG_MCALL saveBuiltinModule( + slang::BuiltinModuleName moduleName, + SlangArchiveType archiveType, + ISlangBlob** outBlob) override; + + SLANG_NO_THROW SlangCapabilityID SLANG_MCALL findCapability(char const* name) override; + + SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerForTransition( + SlangCompileTarget source, + SlangCompileTarget target, + SlangPassThrough compiler) override; + SLANG_NO_THROW SlangPassThrough SLANG_MCALL getDownstreamCompilerForTransition( + SlangCompileTarget source, + SlangCompileTarget target) override; + SLANG_NO_THROW void SLANG_MCALL + getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime) override + { + *outDownstreamTime = m_downstreamCompileTime; + *outTotalTime = m_totalCompileTime; + } + + SLANG_NO_THROW SlangResult SLANG_MCALL setSPIRVCoreGrammar(char const* jsonPath) override; + + SLANG_NO_THROW SlangResult SLANG_MCALL parseCommandLineArguments( + int argc, + const char* const* argv, + slang::SessionDesc* outSessionDesc, + ISlangUnknown** outAllocation) override; + + SLANG_NO_THROW SlangResult SLANG_MCALL + getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) override; + + /// Get the downstream compiler for a transition + IDownstreamCompiler* getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target); + + // This needs to be atomic not because of contention between threads as `Session` is + // *not* multithreaded, but can be used exclusively on one thread at a time. + // The need for atomic is purely for visibility. If the session is used on a different + // thread we need to be sure any changes to m_epochId are visible to this thread. + std::atomic<Index> m_epochId = 1; + + Scope* baseLanguageScope = nullptr; + Scope* coreLanguageScope = nullptr; + Scope* hlslLanguageScope = nullptr; + Scope* slangLanguageScope = nullptr; + Scope* glslLanguageScope = nullptr; + Name* glslModuleName = nullptr; + + ModuleDecl* baseModuleDecl = nullptr; + List<RefPtr<Module>> coreModules; + + SourceManager builtinSourceManager; + + SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } + + // Name pool stuff for unique-ing identifiers + + NamePool namePool; + + NamePool* getNamePool() { return &namePool; } + Name* getNameObj(String name) { return namePool.getName(name); } + Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } + // + + /// This AST Builder should only be used for creating AST nodes that are global across requests + /// not doing so could lead to memory being consumed but not used. + ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; } + void finalizeSharedASTBuilder(); + + RefPtr<ASTBuilder> globalAstBuilder; + + // Generated code for core module, etc. + String coreModulePath; + + ComPtr<ISlangBlob> coreLibraryCode; + // ComPtr<ISlangBlob> slangLibraryCode; + ComPtr<ISlangBlob> hlslLibraryCode; + ComPtr<ISlangBlob> glslLibraryCode; + ComPtr<ISlangBlob> autodiffLibraryCode; + + String getCoreModulePath(); + + ComPtr<ISlangBlob> getCoreLibraryCode(); + ComPtr<ISlangBlob> getHLSLLibraryCode(); + ComPtr<ISlangBlob> getAutodiffLibraryCode(); + ComPtr<ISlangBlob> getGLSLLibraryCode(); + + void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName); + + RefPtr<SharedASTBuilder> m_sharedASTBuilder; + + SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() + { + if (!spirvCoreGrammarInfo) + setSPIRVCoreGrammar(nullptr); + SLANG_ASSERT(spirvCoreGrammarInfo); + return *spirvCoreGrammarInfo; + } + RefPtr<SPIRVCoreGrammarInfo> spirvCoreGrammarInfo; + + // + + void _setSharedLibraryLoader(ISlangSharedLibraryLoader* loader); + + /// Will try to load the library by specified name (using the set loader), if not one already + /// available. + IDownstreamCompiler* getOrLoadDownstreamCompiler(PassThroughMode type, DiagnosticSink* sink); + /// Will unload the specified shared library if it's currently loaded + void resetDownstreamCompiler(PassThroughMode type); + + /// Get the prelude associated with the language + const String& getPreludeForLanguage(SourceLanguage language) + { + return m_languagePreludes[int(language)]; + } + + /// Get the built in linkage -> handy to get the core module from + Linkage* getBuiltinLinkage() const { return m_builtinLinkage; } + + Module* getBuiltinModule(slang::BuiltinModuleName builtinModuleName); + + Name* getCompletionRequestTokenName() const { return m_completionTokenName; } + + void init(); + + void addBuiltinSource( + Scope* scope, + String const& path, + ISlangBlob* sourceBlob, + Module*& outModule); + ~Session(); + + void addDownstreamCompileTime(double time) { m_downstreamCompileTime += time; } + void addTotalCompileTime(double time) { m_totalCompileTime += time; } + + ComPtr<ISlangSharedLibraryLoader> + m_sharedLibraryLoader; ///< The shared library loader (never null) + + int m_downstreamCompilerInitialized = 0; + + RefPtr<DownstreamCompilerSet> + m_downstreamCompilerSet; ///< Information about all available downstream compilers. + ComPtr<IDownstreamCompiler> m_downstreamCompilers[int( + PassThroughMode::CountOf)]; ///< A downstream compiler for a pass through + DownstreamCompilerLocatorFunc m_downstreamCompilerLocators[int(PassThroughMode::CountOf)]; + Name* m_completionTokenName = nullptr; ///< The name of a completion request token. + + /// For parsing command line options + CommandOptions m_commandOptions; + + int m_typeDictionarySize = 0; + + RefPtr<RefObject> m_typeCheckingCache; + TypeCheckingCache* getTypeCheckingCache(); + std::mutex m_typeCheckingCacheMutex; + +private: + struct BuiltinModuleInfo + { + const char* name; + Scope* languageScope; + }; + + BuiltinModuleInfo getBuiltinModuleInfo(slang::BuiltinModuleName name); + + void _initCodeGenTransitionMap(); + + SlangResult _readBuiltinModule( + ISlangFileSystem* fileSystem, + Scope* scope, + String moduleName, + Module*& outModule); + + SlangResult _loadRequest(EndToEndCompileRequest* request, const void* data, size_t size); + + /// Linkage used for all built-in (core module) code. + RefPtr<Linkage> m_builtinLinkage; + + String + m_downstreamCompilerPaths[int(PassThroughMode::CountOf)]; ///< Paths for each pass through + String m_languagePreludes[int(SourceLanguage::CountOf)]; ///< Prelude for each source language + PassThroughMode m_defaultDownstreamCompilers[int(SourceLanguage::CountOf)]; + + // Describes a conversion from one code gen target (source) to another (target) + CodeGenTransitionMap m_codeGenTransitionMap; + + double m_downstreamCompileTime = 0.0; + double m_totalCompileTime = 0.0; +}; + +/* Returns SLANG_OK if pass through support is available */ +SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); + +const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name); + +} // namespace Slang diff --git a/source/slang/slang-linkable-impls.cpp b/source/slang/slang-linkable-impls.cpp new file mode 100644 index 000000000..d03ecb3ca --- /dev/null +++ b/source/slang/slang-linkable-impls.cpp @@ -0,0 +1,752 @@ +// slang-linkable-impls.cpp +#include "slang-linkable-impls.h" + +#include "slang-lower-to-ir.h" // for `generateIRForTypeConformance` +#include "slang-mangle.h" + +namespace Slang +{ + +// +// CompositeComponentType +// + +RefPtr<ComponentType> CompositeComponentType::create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) +{ + // TODO: We should ideally be caching the results of + // composition on the `linkage`, so that if we get + // asked for the same composite again later we re-use + // it rather than re-create it. + // + // Similarly, we might want to do some amount of + // work to "canonicalize" the input for composition. + // E.g., if the user does: + // + // X = compose(A,B); + // Y = compose(C,D); + // Z = compose(X,Y); + // + // W = compose(A, B, C, D); + // + // Then there is no observable difference between + // Z and W, so we might prefer to have them be identical. + + // If there is only a single child, then we should + // just return that child rather than create a dummy composite. + // + if (childComponents.getCount() == 1) + { + return childComponents[0]; + } + + return new CompositeComponentType(linkage, childComponents); +} + + +CompositeComponentType::CompositeComponentType( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents) + : ComponentType(linkage), m_childComponents(childComponents) +{ + HashSet<ComponentType*> requirementsSet; + for (auto child : childComponents) + { + child->enumerateModules([&](Module* module) { requirementsSet.add(module); }); + } + + for (auto child : childComponents) + { + auto childEntryPointCount = child->getEntryPointCount(); + for (Index cc = 0; cc < childEntryPointCount; ++cc) + { + m_entryPoints.add(child->getEntryPoint(cc)); + m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); + m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc)); + } + + auto childShaderParamCount = child->getShaderParamCount(); + for (Index pp = 0; pp < childShaderParamCount; ++pp) + { + m_shaderParams.add(child->getShaderParam(pp)); + } + + auto childSpecializationParamCount = child->getSpecializationParamCount(); + for (Index pp = 0; pp < childSpecializationParamCount; ++pp) + { + m_specializationParams.add(child->getSpecializationParam(pp)); + } + + for (auto module : child->getModuleDependencies()) + { + m_moduleDependencyList.addDependency(module); + } + for (auto sourceFile : child->getFileDependencies()) + { + m_fileDependencyList.addDependency(sourceFile); + } + + auto childRequirementCount = child->getRequirementCount(); + for (Index rr = 0; rr < childRequirementCount; ++rr) + { + auto childRequirement = child->getRequirement(rr); + if (!requirementsSet.contains(childRequirement)) + { + requirementsSet.add(childRequirement); + m_requirements.add(childRequirement); + } + } + } +} + +void CompositeComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + auto componentCount = getChildComponentCount(); + + for (Index i = 0; i < componentCount; ++i) + { + getChildComponent(i)->buildHash(builder); + } +} + +Index CompositeComponentType::getEntryPointCount() +{ + return m_entryPoints.getCount(); +} + +RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) +{ + return m_entryPoints[index]; +} + +String CompositeComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + +String CompositeComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + +Index CompositeComponentType::getShaderParamCount() +{ + return m_shaderParams.getCount(); +} + +ShaderParamInfo CompositeComponentType::getShaderParam(Index index) +{ + return m_shaderParams[index]; +} + +Index CompositeComponentType::getSpecializationParamCount() +{ + return m_specializationParams.getCount(); +} + +SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) +{ + return m_specializationParams[index]; +} + +Index CompositeComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +List<Module*> const& CompositeComponentType::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} + +List<SourceFile*> const& CompositeComponentType::getFileDependencies() +{ + return m_fileDependencyList.getFileList(); +} + +void CompositeComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); +} + +RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(argCount); + + RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); + + Index offset = 0; + for (auto child : m_childComponents) + { + auto childParamCount = child->getSpecializationParamCount(); + SLANG_ASSERT(offset + childParamCount <= argCount); + + auto childInfo = child->_validateSpecializationArgs(args + offset, childParamCount, sink); + + specializationInfo->childInfos.add(childInfo); + + offset += childParamCount; + } + return specializationInfo; +} + +// +// SpecializedComponentType +// + +/// Utility type for collecting modules references by types/declarations +struct SpecializationArgModuleCollector : ComponentTypeVisitor +{ + HashSet<Module*> m_modulesSet; + List<Module*> m_modulesList; + + void addModule(Module* module) + { + m_modulesList.add(module); + m_modulesSet.add(module); + } + + void maybeAddModule(Module* module) + { + if (!module) + return; + if (m_modulesSet.contains(module)) + return; + + addModule(module); + } + + void collectReferencedModules(Decl* decl) + { + auto module = getModule(decl); + maybeAddModule(module); + } + + void collectReferencedModules(SubstitutionSet substitutions) + { + substitutions.forEachGenericSubstitution( + [this](GenericDecl*, Val::OperandView<Val> args) + { + for (auto arg : args) + { + collectReferencedModules(arg); + } + }); + } + + void collectReferencedModules(DeclRefBase* declRef) + { + collectReferencedModules(declRef->getDecl()); + collectReferencedModules(SubstitutionSet(declRef)); + } + + void collectReferencedModules(Type* type) + { + if (auto declRefType = as<DeclRefType>(type)) + { + collectReferencedModules(declRefType->getDeclRef()); + } + + // TODO: Handle non-decl-ref composite type cases + // (e.g., function types). + } + + void collectReferencedModules(Val* val) + { + if (auto type = as<Type>(val)) + { + collectReferencedModules(type); + } + else if (auto declRefVal = as<DeclRefIntVal>(val)) + { + collectReferencedModules(declRefVal->getDeclRef()); + } + + // TODO: other cases of values that could reference + // a declaration. + } + + void collectReferencedModules(List<ExpandedSpecializationArg> const& args) + { + for (auto arg : args) + { + collectReferencedModules(arg.val); + collectReferencedModules(arg.witness); + } + } + + // + // ComponentTypeVisitor methods + // + + void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + SLANG_UNUSED(entryPoint); + + if (!specializationInfo) + return; + + collectReferencedModules(specializationInfo->specializedFuncDeclRef); + collectReferencedModules(specializationInfo->existentialSpecializationArgs); + } + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) + SLANG_OVERRIDE + { + SLANG_UNUSED(module); + + if (!specializationInfo) + return; + + for (auto arg : specializationInfo->genericArgs) + { + collectReferencedModules(arg.argVal); + } + collectReferencedModules(specializationInfo->existentialArgs); + } + + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } +}; + +SpecializedComponentType::SpecializedComponentType( + ComponentType* base, + ComponentType::SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink) + : ComponentType(base->getLinkage()) + , m_base(base) + , m_specializationInfo(specializationInfo) + , m_specializationArgs(specializationArgs) +{ + m_optionSet.overrideWith(base->getOptionSet()); + + m_irModule = generateIRForSpecializedComponentType(this, sink); + + // We need to account for the fact that a specialized + // entity like `myShader<SomeType>` needs to not only + // depend on the module(s) that `myShader` depends on, + // but also on any modules that `SomeType` depends on. + // + // We will set up a "collector" type that will be + // used to build a list of these additional modules. + // + SpecializationArgModuleCollector moduleCollector; + + // We don't want to go adding additional requirements for + // modules that the base component type already includes, + // so we will add those to the set of modules in + // the collector before we starting trying to add others. + // + base->enumerateModules([&](Module* module) { moduleCollector.m_modulesSet.add(module); }); + + // In order to collect the additional modules, we need + // to inspect the specialization arguments and see what + // they depend on. + // + // Naively, it seems like we'd just want to iterate + // over `specializationArgs`, which gives the specialization + // arguments as the user supplied them. However, such + // an approach would have a subtle problem. + // + // If we have a generic entry point like: + // + // // In module A + // myShader<T : IThing> + // + // + // And the type `SomeType` that is being used as an argument doesn't + // directly conform to `IThing`: + // + // // In module B + // struct SomeType { ... } + // + // and the conformance of `SomeType` to `IThing` is + // coming from yet another module: + // + // // In module C + // import B; + // extension SomeType : IThing { ... } + // + // In this case, the specialized component for `myShader<SomeType>` + // needs to depend on all of: + // + // * Module A, because it defines `myShader` + // * Module B, because it defines `SomeType` + // * Module C, because it defines the conformance `SomeType : IThing` + // + // We thus need to iterate over a form of the specialization + // arguments that includes the "expanded" arguments like + // interface conformance witnesses that got added during + // semantic checking. + // + // The expanded arguments are being stored in the `specializationInfo` + // today (for use by downstream code generation), and the easiest + // way to walk that information and get to the leaf nodes where + // the expanded arguments are stored is to apply a visitor to + // the specialized component type we are in the middle of constructing. + // + moduleCollector.visitSpecialized(this); + + // Now that we've collected our additional information, we can + // start to build up the final lists for the specialized component type. + // + // The starting point for our lists comes from the base component type. + // + m_moduleDependencies = base->getModuleDependencies(); + m_fileDependencies = base->getFileDependencies(); + + Index baseRequirementCount = base->getRequirementCount(); + for (Index r = 0; r < baseRequirementCount; r++) + { + m_requirements.add(base->getRequirement(r)); + } + + // The specialized component type will need to have additional + // dependencies and requirements based on the modules that + // were collected when looking at the specialization arguments. + + // We want to avoid adding the same file dependency more than once. + // + HashSet<SourceFile*> fileDependencySet; + for (SourceFile* sourceFile : m_fileDependencies) + fileDependencySet.add(sourceFile); + + for (auto module : moduleCollector.m_modulesList) + { + // The specialized component type will have an open (unsatisfied) + // requirement for each of the modules that its specialization + // arguments need. + // + // Note: what this means in practice is that the component type + // records that the given module(s) will need to be linked in + // before final code can be generated, but it importantly + // does not dictate the final placement of the parameters from + // those modules in the layout. + // + m_requirements.add(module); + + // The speciialized component type will also have a dependency + // on all the files that any of the modules involved in + // it depend on (including those that are required but not + // yet linked in). + // + // The file path information is what a client would need to + // use to decide if kernel code is out of date compared to + // source files, so we want to include anything that could + // affect the validity of generated code. + // + for (SourceFile* sourceFile : module->getFileDependencies()) + { + if (fileDependencySet.contains(sourceFile)) + continue; + fileDependencySet.add(sourceFile); + m_fileDependencies.add(sourceFile); + } + + // Finalyl we also add the module for the specialization arguments + // to the list of modules that would be used for legacy lookup + // operations where we need an implicit/default scope to use + // and want it to be expansive. + // + // TODO: This stuff really isn't worth keeping around long + // term, and we should ditch the entire "legacy lookup" idea. + // + m_moduleDependencies.add(module); + } + + // Because we are specializing shader code, the mangled entry + // point names for this component type may be different than + // for the base component type (e.g., the mangled name for `f<int>` + // is different than that that of the generic `f` function + // itself). + // + // We will compute the mangled names of all the entry points and + // store them here, so that we don't have to do it on the fly. + // Because the `ComponentType` structure is hierarchical, we + // need to use a recursive visitor to compute the names, + // and we will define that visitor locally: + // + struct EntryPointMangledNameCollector : ComponentTypeVisitor + { + List<String>* mangledEntryPointNames; + List<String>* entryPointNameOverrides; + + void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + auto funcDeclRef = entryPoint->getFuncDeclRef(); + if (specializationInfo) + funcDeclRef = specializationInfo->specializedFuncDeclRef; + + (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); + (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); + } + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + (*entryPointNameOverrides).getLast() = entryPoint->getEntryPointNameOverride(0); + } + + void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE {} + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } + EntryPointMangledNameCollector(ASTBuilder* astBuilder) + : m_astBuilder(astBuilder) + { + } + ASTBuilder* m_astBuilder; + }; + + // With the visitor defined, we apply it to ourself to compute + // and collect the mangled entry point names. + // + EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); + collector.mangledEntryPointNames = &m_entryPointMangledNames; + collector.entryPointNameOverrides = &m_entryPointNameOverrides; + collector.visitSpecialized(this); +} + +void SpecializedComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + auto specializationArgCount = getSpecializationArgCount(); + for (Index i = 0; i < specializationArgCount; ++i) + { + auto specializationArg = getSpecializationArg(i); + auto argString = specializationArg.val->toString(); + builder.append(argString); + } + + getBaseComponentType()->buildHash(builder); +} + +void SpecializedComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + SLANG_ASSERT(specializationInfo == nullptr); + SLANG_UNUSED(specializationInfo); + visitor->visitSpecialized(this); +} + +Index SpecializedComponentType::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) +{ + return m_requirements[index]; +} + +String SpecializedComponentType::getEntryPointMangledName(Index index) +{ + return m_entryPointMangledNames[index]; +} + +String SpecializedComponentType::getEntryPointNameOverride(Index index) +{ + return m_entryPointNameOverrides[index]; +} + +// +// RenamedEntryPointComponentType +// + +RenamedEntryPointComponentType::RenamedEntryPointComponentType(ComponentType* base, String newName) + : ComponentType(base->getLinkage()), m_base(base), m_entryPointNameOverride(newName) +{ +} + +void RenamedEntryPointComponentType::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitRenamedEntryPoint( + this, + as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo)); +} + +void RenamedEntryPointComponentType::buildHash(DigestBuilder<SHA1>& builder) +{ + SLANG_UNUSED(builder); +} + +// +// TypeConformance +// + +TypeConformance::TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_subtypeWitness(witness) + , m_conformanceIdOverride(confomrmanceIdOverride) +{ + addDepedencyFromWitness(witness); + m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); +} + +void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) +{ + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); + m_moduleDependencyList.addDependency(declModule); + m_fileDependencyList.addDependency(declModule); + if (m_requirementSet.add(declModule)) + { + m_requirements.add(declModule); + } + // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. + } + else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) + { + addDepedencyFromWitness(transitiveWitness->getMidToSup()); + addDepedencyFromWitness(transitiveWitness->getSubToMid()); + } + else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) + { + auto componentCount = conjunctionWitness->getComponentCount(); + for (Index i = 0; i < componentCount; ++i) + { + auto w = as<SubtypeWitness>(conjunctionWitness->getComponentWitness(i)); + if (w) + addDepedencyFromWitness(w); + } + } +} + +ISlangUnknown* TypeConformance::getInterface(const Guid& guid) +{ + if (guid == slang::ITypeConformance::getTypeGuid()) + return static_cast<slang::ITypeConformance*>(this); + + return Super::getInterface(guid); +} + +void TypeConformance::buildHash(DigestBuilder<SHA1>& builder) +{ + // TODO: Implement some kind of hashInto for Val then replace this + auto subtypeWitness = m_subtypeWitness->toString(); + + builder.append(subtypeWitness); + builder.append(m_conformanceIdOverride); +} + +List<Module*> const& TypeConformance::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} + +List<SourceFile*> const& TypeConformance::getFileDependencies() +{ + return m_fileDependencyList.getFileList(); +} + +Index TypeConformance::getRequirementCount() +{ + return m_requirements.getCount(); +} + +RefPtr<ComponentType> TypeConformance::getRequirement(Index index) +{ + return m_requirements[index]; +} + +void TypeConformance::acceptVisitor( + ComponentTypeVisitor* visitor, + ComponentType::SpecializationInfo* specializationInfo) +{ + SLANG_UNUSED(specializationInfo); + visitor->visitTypeConformance(this); +} + +RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; +} + +// +// ComponentTypeVisitor +// + +void ComponentTypeVisitor::visitChildren( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) +{ + auto childCount = composite->getChildComponentCount(); + for (Index ii = 0; ii < childCount; ++ii) + { + auto child = composite->getChildComponent(ii); + auto childSpecializationInfo = + specializationInfo ? specializationInfo->childInfos[ii] : nullptr; + + child->acceptVisitor(this, childSpecializationInfo); + } +} + +void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) +{ + specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); +} + +} // namespace Slang diff --git a/source/slang/slang-linkable-impls.h b/source/slang/slang-linkable-impls.h new file mode 100644 index 000000000..68a16587d --- /dev/null +++ b/source/slang/slang-linkable-impls.h @@ -0,0 +1,566 @@ +// slang-linkable-impl.h +#pragma once + +// +// This file declares various implementations of linkable +// objects (subclasses of `ComponentType`). +// +// Note that the base `ComponentType` class is declared +// in `slang-linkable.h`. +// +// Note that the most important two classes of linkable +// objects, `Module`s and `EntryPoint`s, have their own +// headers: `slang-module.h` and `slang-entry-point.h`, +// respectively. +// + +#include "slang-entry-point.h" +#include "slang-linkable.h" +#include "slang-module.h" + +namespace Slang +{ +/// A component type built up from other component types. +class CompositeComponentType : public ComponentType +{ +public: + static RefPtr<ComponentType> create( + Linkage* linkage, + List<RefPtr<ComponentType>> const& childComponents); + + virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; + + List<RefPtr<ComponentType>> const& getChildComponents() { return m_childComponents; }; + Index getChildComponentCount() { return m_childComponents.getCount(); } + RefPtr<ComponentType> getChildComponent(Index index) { return m_childComponents[index]; } + + Index getEntryPointCount() SLANG_OVERRIDE; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE; + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE; + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; + List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE; + + class CompositeSpecializationInfo : public SpecializationInfo + { + public: + List<RefPtr<SpecializationInfo>> childInfos; + }; + +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +public: + CompositeComponentType(Linkage* linkage, List<RefPtr<ComponentType>> const& childComponents); + +private: + List<RefPtr<ComponentType>> m_childComponents; + + // The following arrays hold the concatenated entry points, parameters, + // etc. from the child components. This approach allows for reasonably + // fast (constant time) access through operations like `getShaderParam`, + // but means that the memory usage of a composite is proportional to + // the sum of the memory usage of the children, rather than being fixed + // by the number of children (as it would be if we just stored + // `m_childComponents`). + // + // TODO: We could conceivably build some O(numChildren) arrays that + // support binary-search to provide logarithmic-time access to entry + // points, parameters, etc. while giving a better overall memory usage. + // + List<EntryPoint*> m_entryPoints; + List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; + List<ShaderParamInfo> m_shaderParams; + List<SpecializationParam> m_specializationParams; + List<ComponentType*> m_requirements; + + ModuleDependencyList m_moduleDependencyList; + FileDependencyList m_fileDependencyList; +}; + +/// A component type created by specializing another component type. +class SpecializedComponentType : public ComponentType +{ +public: + SpecializedComponentType( + ComponentType* base, + SpecializationInfo* specializationInfo, + List<SpecializationArg> const& specializationArgs, + DiagnosticSink* sink); + + virtual void buildHash(DigestBuilder<SHA1>& builer) SLANG_OVERRIDE; + + /// Get the base (unspecialized) component type that is being specialized. + RefPtr<ComponentType> getBaseComponentType() { return m_base; } + + RefPtr<SpecializationInfo> getSpecializationInfo() { return m_specializationInfo; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized entry point may have many parameters, but zero arguments. + Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + SpecializationArg const& getSpecializationArg(Index index) + { + return m_specializationArgs[index]; + } + + /// Get an array of all existential type arguments. + SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } + + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPoint(index); + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + return m_base->getShaderParam(index); + } + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + static SpecializationParam dummy; + return dummy; + } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } + List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } + + RefPtr<IRModule> getIRModule() { return m_irModule; } + + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + +protected: + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } + +private: + RefPtr<ComponentType> m_base; + RefPtr<SpecializationInfo> m_specializationInfo; + SpecializationArgs m_specializationArgs; + RefPtr<IRModule> m_irModule; + + List<String> m_entryPointMangledNames; + List<String> m_entryPointNameOverrides; + + List<Module*> m_moduleDependencies; + List<SourceFile*> m_fileDependencies; + List<RefPtr<ComponentType>> m_requirements; +}; + +class RenamedEntryPointComponentType : public ComponentType +{ +public: + using Super = ComponentType; + + RenamedEntryPointComponentType(ComponentType* base, String newName); + + ComponentType* getBase() { return m_base.Ptr(); } + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE + { + return m_base->getModuleDependencies(); + } + List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE + { + return m_base->getFileDependencies(); + } + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE + { + return m_base->getSpecializationParamCount(); + } + + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + return m_base->getSpecializationParam(index); + } + + Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); } + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE + { + return m_base->getRequirement(index); + } + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPoint(index); + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPointMangledName(index); + } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + return m_entryPointNameOverride; + } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + return m_base->getShaderParam(index); + } + + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; + +private: + RefPtr<ComponentType> m_base; + String m_entryPointNameOverride; + +protected: + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE + { + return m_base->_validateSpecializationArgsImpl(args, argCount, sink); + } +}; + +class TypeConformance : public ComponentType, public slang::ITypeConformance +{ + typedef ComponentType Super; + +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Guid& guid); + + TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink); + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCompileResult( + entryPointIndex, + targetIndex, + outCompileResult, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } + + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; + List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE; + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE + { + static SpecializationParam emptyParam; + return emptyParam; + } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } + + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } + + SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } + IRModule* getIRModule() { return m_irModule.Ptr(); } + +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +private: + SubtypeWitness* m_subtypeWitness; + ModuleDependencyList m_moduleDependencyList; + FileDependencyList m_fileDependencyList; + List<RefPtr<Module>> m_requirements; + HashSet<Module*> m_requirementSet; + RefPtr<IRModule> m_irModule; + Int m_conformanceIdOverride; + void addDepedencyFromWitness(SubtypeWitness* witness); +}; + +/// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. +class ComponentTypeVisitor +{ +public: + // The following methods should be overriden in a concrete subclass + // to customize how it acts on each of the concrete types of component. + // + // In cases where the application wants to simply "recurse" on a + // composite, specialized, or legacy component type it can use + // the `visitChildren` methods below. + // + virtual void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + virtual void visitModule( + Module* module, + Module::ModuleSpecializationInfo* specializationInfo) = 0; + virtual void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; + virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitTypeConformance(TypeConformance* conformance) = 0; + virtual void visitRenamedEntryPoint( + RenamedEntryPointComponentType* renamedEntryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + +protected: + // These helpers can be used to recurse into the logical children of a + // component type, and are useful for the common case where a visitor + // only cares about a few leaf cases. + // + void visitChildren( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo); + void visitChildren(SpecializedComponentType* specialized); +}; + +} // namespace Slang diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp new file mode 100644 index 000000000..da4cec823 --- /dev/null +++ b/source/slang/slang-linkable.cpp @@ -0,0 +1,1037 @@ +// slang-linkable.cpp +#include "slang-linkable.h" + +#include "compiler-core/slang-artifact-container-util.h" +#include "compiler-core/slang-artifact-desc-util.h" +#include "compiler-core/slang-artifact-impl.h" +#include "core/slang-char-util.h" +#include "core/slang-memory-file-system.h" +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-mangle.h" + +namespace Slang +{ + +// +// ModuleDependencyList +// + +void ModuleDependencyList::addDependency(Module* module) +{ + // If we depend on a module, then we depend on everything it depends on. + // + // Note: We are processing these sub-depenencies before adding the + // `module` itself, so that in the common case a module will always + // appear *after* everything it depends on. + // + // However, this rule is being violated in the compiler right now because + // the modules for hte top-level translation units of a compile request + // will be added to the list first (using `addLeafDependency`) to + // maintain compatibility with old behavior. This may be fixed later. + // + for (auto subDependency : module->getModuleDependencyList()) + { + _addDependency(subDependency); + } + _addDependency(module); +} + +void ModuleDependencyList::addLeafDependency(Module* module) +{ + _addDependency(module); +} + +void ModuleDependencyList::_addDependency(Module* module) +{ + if (m_moduleSet.contains(module)) + return; + + m_moduleList.add(module); + m_moduleSet.add(module); +} + +// +// FileDependencyList +// + +void FileDependencyList::addDependency(SourceFile* sourceFile) +{ + if (m_fileSet.contains(sourceFile)) + return; + + m_fileList.add(sourceFile); + m_fileSet.add(sourceFile); +} + +void FileDependencyList::addDependency(Module* module) +{ + for (SourceFile* sourceFile : module->getFileDependencyList()) + { + addDependency(sourceFile); + } +} + +// +// ComponentType +// + +ComponentType::ComponentType(Linkage* linkage) + : m_linkage(linkage) +{ +} + +ComponentType* asInternal(slang::IComponentType* inComponentType) +{ + // Note: we use a `queryInterface` here instead of just a `static_cast` + // to ensure that the `IComponentType` we get is the preferred/canonical + // one, which shares its address with the `ComponentType`. + // + // TODO: An alternative choice here would be to have a "magic" IID that + // we pass into `queryInterface` that returns the `ComponentType` directly + // (without even `addRef`-ing it). + // + ComPtr<slang::IComponentType> componentType; + inComponentType->queryInterface(SLANG_IID_PPV_ARGS(componentType.writeRef())); + return static_cast<ComponentType*>(componentType.get()); +} + +ISlangUnknown* ComponentType::getInterface(Guid const& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == slang::IComponentType::getTypeGuid()) + { + return static_cast<slang::IComponentType*>(this); + } + if (guid == IModulePrecompileService_Experimental::getTypeGuid()) + return static_cast<slang::IModulePrecompileService_Experimental*>(this); + if (guid == IComponentType2::getTypeGuid()) + return static_cast<slang::IComponentType2*>(this); + return nullptr; +} + +SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() +{ + return m_linkage; +} + +SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL +ComponentType::getLayout(Int targetIndex, slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return nullptr; + auto target = linkage->targets[targetIndex]; + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + auto programLayout = getTargetProgram(target)->getOrCreateLayout(&sink); + sink.getBlobIfNeeded(outDiagnostics); + + return asExternal(programLayout); +} + +static ICastable* _findDiagnosticRepresentation(IArtifact* artifact) +{ + if (auto rep = findAssociatedRepresentation<IArtifactDiagnostics>(artifact)) + { + return rep; + } + + for (auto associated : artifact->getAssociated()) + { + if (isDerivedFrom(associated->getDesc().payload, ArtifactPayload::Diagnostics)) + { + return associated; + } + } + return nullptr; +} + +static IArtifact* _findObfuscatedSourceMap(IArtifact* artifact) +{ + // If we find any obfuscated source maps, we are done + for (auto associated : artifact->getAssociated()) + { + const auto desc = associated->getDesc(); + + if (isDerivedFrom(desc.payload, ArtifactPayload::SourceMap) && + isDerivedFrom(desc.style, ArtifactStyle::Obfuscated)) + { + return associated; + } + } + return nullptr; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getResultAsFileSystem( + SlangInt entryPointIndex, + Int targetIndex, + ISlangMutableFileSystem** outFileSystem) +{ + ComPtr<ISlangBlob> diagnostics; + ComPtr<ISlangBlob> code; + + SLANG_RETURN_ON_FAIL( + getEntryPointCode(entryPointIndex, targetIndex, diagnostics.writeRef(), code.writeRef())); + + auto linkage = getLinkage(); + + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + IArtifact* artifact = targetProgram->getExistingEntryPointResult(entryPointIndex); + + // Add diagnostics id needs be... + if (diagnostics && !_findDiagnosticRepresentation(artifact)) + { + // Add as an associated + + auto diagnosticsArtifact = Artifact::create( + ArtifactDesc::make(Artifact::Kind::HumanText, ArtifactPayload::Diagnostics)); + diagnosticsArtifact->addRepresentationUnknown(diagnostics); + + artifact->addAssociated(diagnosticsArtifact); + + SLANG_ASSERT(diagnosticsArtifact == _findDiagnosticRepresentation(artifact)); + } + + // Add obfuscated source maps + if (!_findObfuscatedSourceMap(artifact)) + { + List<IRModule*> irModules; + enumerateIRModules([&](IRModule* irModule) -> void { irModules.add(irModule); }); + + for (auto irModule : irModules) + { + if (auto obfuscatedSourceMap = irModule->getObfuscatedSourceMap()) + { + auto artifactDesc = ArtifactDesc::make( + ArtifactKind::Json, + ArtifactPayload::SourceMap, + ArtifactStyle::Obfuscated); + + // Create the source map artifact + auto sourceMapArtifact = Artifact::create( + artifactDesc, + obfuscatedSourceMap->get().m_file.getUnownedSlice()); + + sourceMapArtifact->addRepresentation(obfuscatedSourceMap); + + // associate with the artifact + artifact->addAssociated(sourceMapArtifact); + } + } + } + + // Turn into a file system and return + ComPtr<ISlangMutableFileSystem> fileSystem(new MemoryFileSystem); + + // Filter the containerArtifact into things that can be written + ComPtr<IArtifact> writeArtifact; + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(artifact, writeArtifact)); + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, "", fileSystem)); + + *outFileSystem = fileSystem.detach(); + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( + SlangInt entryPointIndex, + Int targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return SLANG_E_INVALID_ARG; + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); + sink.getBlobIfNeeded(outDiagnostics); + + if (artifact == nullptr) + return SLANG_FAIL; + + return artifact->loadBlob(ArtifactKeep::Yes, outCode); +} + +SLANG_NO_THROW void SLANG_MCALL ComponentType::getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) +{ + DigestBuilder<SHA1> builder; + + // A note on enums that may be hashed in as part of the following two function calls: + // + // While enums are not guaranteed to be encoded the same way across all versions of + // the compiler, part of hashing the linkage is hashing in the compiler version. + // Consequently, any encoding differences as a result of different compiler versions + // will already be reflected in the resulting hash. + getLinkage()->buildHash(builder, targetIndex); + + buildHash(builder); + + // Add the name and name override for the specified entry point to the hash. + auto entryPoint = getEntryPoint(entryPointIndex); + if (entryPoint) + { + auto entryPointName = entryPoint->getName()->text; + builder.append(entryPointName); + auto entryPointMangledName = getEntryPointMangledName(entryPointIndex); + builder.append(entryPointMangledName); + auto entryPointNameOverride = getEntryPointNameOverride(entryPointIndex); + builder.append(entryPointNameOverride); + } + + auto hash = builder.finalize().toBlob(); + *outHash = hash.detach(); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return SLANG_E_INVALID_ARG; + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); + sink.getBlobIfNeeded(outDiagnostics); + + if (artifact == nullptr) + return SLANG_FAIL; + + return artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointMetadata( + SlangInt entryPointIndex, + Int targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return SLANG_E_INVALID_ARG; + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); + sink.getBlobIfNeeded(outDiagnostics); + + if (artifact == nullptr) + return SLANG_E_NOT_AVAILABLE; + + auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); + if (!metadata) + return SLANG_E_NOT_AVAILABLE; + + *outMetadata = static_cast<slang::IMetadata*>(metadata); + (*outMetadata)->addRef(); + return SLANG_OK; +} + +RefPtr<ComponentType> ComponentType::specialize( + SpecializationArg const* inSpecializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink) +{ + if (specializationArgCount == 0) + { + return this; + } + + List<SpecializationArg> specializationArgs; + specializationArgs.addRange(inSpecializationArgs, specializationArgCount); + + // We next need to validate that the specialization arguments + // make sense, and also expand them to include any derived data + // (e.g., interface conformance witnesses) that doesn't get + // passed explicitly through the API interface. + // + RefPtr<SpecializationInfo> specializationInfo = + _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink); + + return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) +{ + DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); + + // First let's check if the number of arguments given matches + // the number of parameters that are present on this component type. + // + auto specializationParamCount = getSpecializationParamCount(); + if (specializationArgCount != specializationParamCount) + { + sink.diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + specializationParamCount, + specializationArgCount); + sink.getBlobIfNeeded(outDiagnostics); + return SLANG_FAIL; + } + + List<SpecializationArg> expandedArgs; + for (Int aa = 0; aa < specializationArgCount; ++aa) + { + auto apiArg = specializationArgs[aa]; + + SpecializationArg expandedArg; + switch (apiArg.kind) + { + case slang::SpecializationArg::Kind::Type: + expandedArg.val = asInternal(apiArg.type); + break; + + default: + sink.getBlobIfNeeded(outDiagnostics); + return SLANG_FAIL; + } + expandedArgs.add(expandedArg); + } + + auto specializedComponentType = + specialize(expandedArgs.getBuffer(), expandedArgs.getCount(), &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + *outSpecializedComponentType = specializedComponentType.detach(); + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ComponentType::renameEntryPoint(const char* newName, IComponentType** outEntryPoint) +{ + RefPtr<RenamedEntryPointComponentType> result = + new RenamedEntryPointComponentType(this, newName); + *outEntryPoint = result.detach(); + return SLANG_OK; +} + +RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType); + +SLANG_NO_THROW SlangResult SLANG_MCALL +ComponentType::link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) +{ + // TODO: It should be possible for `fillRequirements` to fail, + // in cases where we have a dependency that can't be automatically + // resolved. + // + SLANG_UNUSED(outDiagnostics); + + DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); + + try + { + auto linked = fillRequirements(this); + if (!linked) + return SLANG_FAIL; + + *outLinkedComponentType = ComPtr<slang::IComponentType>(linked).detach(); + return SLANG_OK; + } + catch (const AbortCompilationException& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return SLANG_FAIL; + } + catch (const Exception& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return SLANG_FAIL; + } + catch (...) + { + outputExceptionDiagnostic(sink, outDiagnostics); + return SLANG_FAIL; + } +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) +{ + SLANG_RETURN_ON_FAIL(link(outLinkedComponentType, outDiagnostics)); + + auto linked = *outLinkedComponentType; + + if (linked) + { + static_cast<ComponentType*>(linked)->getOptionSet().load(count, entries); + } + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCompileResult( + SlangInt entryPointIndex, + Int targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return SLANG_E_INVALID_ARG; + auto target = linkage->targets[targetIndex]; + + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); + sink.getBlobIfNeeded(outDiagnostics); + if (artifact == nullptr) + return SLANG_E_NOT_AVAILABLE; + + *outCompileResult = static_cast<slang::ICompileResult*>(artifact); + (*outCompileResult)->addRef(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCompileResult( + Int targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) +{ + IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); + if (artifact == nullptr) + return SLANG_E_NOT_AVAILABLE; + + *outCompileResult = static_cast<slang::ICompileResult*>(artifact); + (*outCompileResult)->addRef(); + return SLANG_OK; +} + +/// Visitor used by `ComponentType::enumerateModules` +struct EnumerateModulesVisitor : ComponentTypeVisitor +{ + EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) + : m_callback(callback), m_userData(userData) + { + } + + ComponentType::EnumerateModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module, m_userData); + } + + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + SLANG_UNUSED(conformance); + } +}; + + +void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) +{ + EnumerateModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +/// Visitor used by `ComponentType::enumerateIRModules` +struct EnumerateIRModulesVisitor : ComponentTypeVisitor +{ + EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) + : m_callback(callback), m_userData(userData) + { + } + + ComponentType::EnumerateIRModulesCallback m_callback; + void* m_userData; + + void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} + + void visitRenamedEntryPoint( + RenamedEntryPointComponentType* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + entryPoint->getBase()->acceptVisitor(this, specializationInfo); + } + + void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE + { + m_callback(module->getIRModule(), m_userData); + } + + void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE + { + visitChildren(composite, specializationInfo); + } + + void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE + { + visitChildren(specialized); + + m_callback(specialized->getIRModule(), m_userData); + } + + void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE + { + m_callback(conformance->getIRModule(), m_userData); + } +}; + +void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) +{ + EnumerateIRModulesVisitor visitor(callback, userData); + acceptVisitor(&visitor, nullptr); +} + +IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outDiagnostics) +{ + auto linkage = getLinkage(); + if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) + return nullptr; + ComPtr<IArtifact> artifact; + if (m_targetArtifacts.tryGetValue(targetIndex, artifact)) + { + return artifact.get(); + } + + // If the user hasn't specified any entry points, then we should + // discover all entrypoints that are defined in linked modules, and + // include all of them in the compile. + // + if (getEntryPointCount() == 0) + { + List<Module*> modules; + this->enumerateModules([&](Module* module) { modules.add(module); }); + List<RefPtr<ComponentType>> components; + components.add(this); + bool entryPointsDiscovered = false; + for (auto module : modules) + { + for (auto entryPoint : module->getEntryPoints()) + { + components.add(entryPoint); + entryPointsDiscovered = true; + } + } + + // If any entry points were discovered, then we should emit the program with entrypoints + // linked. + if (entryPointsDiscovered) + { + RefPtr<CompositeComponentType> composite = + new CompositeComponentType(linkage, components); + ComPtr<IComponentType> linkedComponentType; + SLANG_RETURN_NULL_ON_FAIL( + composite->link(linkedComponentType.writeRef(), outDiagnostics)); + auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get()) + ->getTargetArtifact(targetIndex, outDiagnostics); + if (targetArtifact) + { + m_targetArtifacts[targetIndex] = targetArtifact; + } + return targetArtifact; + } + } + + auto target = linkage->targets[targetIndex]; + auto targetProgram = getTargetProgram(target); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + IArtifact* targetArtifact = targetProgram->getOrCreateWholeProgramResult(&sink); + sink.getBlobIfNeeded(outDiagnostics); + m_targetArtifacts[targetIndex] = ComPtr<IArtifact>(targetArtifact); + return targetArtifact; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ComponentType::getTargetCode(Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) +{ + IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); + + if (artifact == nullptr) + return SLANG_FAIL; + + return artifact->loadBlob(ArtifactKeep::Yes, outCode); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata( + Int targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) +{ + IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); + + if (artifact == nullptr) + return SLANG_FAIL; + + auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); + if (!metadata) + return SLANG_E_NOT_AVAILABLE; + *outMetadata = static_cast<slang::IMetadata*>(metadata); + (*outMetadata)->addRef(); + return SLANG_OK; +} + +Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink) +{ + // If we've looked up this type name before, + // then we can re-use it. + // + Type* type = nullptr; + if (m_types.tryGetValue(typeStr, type)) + return type; + + + // TODO(JS): For now just used the linkages ASTBuilder to keep on scope + // + // The parseTermString uses the linkage ASTBuilder for it's parsing. + // + // It might be possible to just create a temporary ASTBuilder - the worry though is + // that the parsing sets a member variable in AST node to one of these scopes, and then + // it become a dangling pointer. So for now we go with the linkages. + auto astBuilder = getLinkage()->getASTBuilder(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + Expr* typeExpr = linkage->parseTermString(typeStr, scope); + SharedSemanticsContext sharedSemanticsContext(linkage, nullptr, sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + type = visitor.TranslateTypeNode(typeExpr); + auto typeOut = visitor.tryCoerceToProperType(TypeExp(type)); + if (typeOut.type) + type = typeOut.type; + + if (type) + { + m_types[typeStr] = type; + } + return type; +} + +Expr* ComponentType::findDeclFromString(String const& name, DiagnosticSink* sink) +{ + // If we've looked up this type name before, + // then we can re-use it. + // + Expr* result = nullptr; + if (m_decls.tryGetValue(name, result)) + return result; + + + // TODO(JS): For now just used the linkages ASTBuilder to keep on scope + // + // The parseTermString uses the linkage ASTBuilder for it's parsing. + // + // It might be possible to just create a temporary ASTBuilder - the worry though is + // that the parsing sets a member variable in AST node to one of these scopes, and then + // it become a dangling pointer. So for now we go with the linkages. + auto astBuilder = getLinkage()->getASTBuilder(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + Expr* expr = linkage->parseTermString(name, scope); + + SemanticsContext context(linkage->getSemanticsForReflection()); + context = context.allowStaticReferenceToNonStaticMember().withSink(sink); + + SemanticsVisitor visitor(context); + + auto checkedExpr = visitor.CheckTerm(expr); + + if (as<DeclRefExpr>(checkedExpr) || as<OverloadedExpr>(checkedExpr)) + { + result = checkedExpr; + } + + m_decls[name] = result; + return result; +} + +static bool _isSimpleName(String const& name) +{ + for (char c : name) + { + if (!CharUtil::isAlphaOrDigit(c) && c != '_' && c != '$') + return false; + } + return true; +} + +Expr* ComponentType::findDeclFromStringInType( + Type* type, + String const& name, + LookupMask mask, + DiagnosticSink* sink) +{ + // Only look up in the type if it is a DeclRefType + if (!as<DeclRefType>(type)) + return nullptr; + + // TODO(JS): For now just used the linkages ASTBuilder to keep on scope + // + // The parseTermString uses the linkage ASTBuilder for it's parsing. + // + // It might be possible to just create a temporary ASTBuilder - the worry though is + // that the parsing sets a member variable in AST node to one of these scopes, and then + // it become a dangling pointer. So for now we go with the linkages. + auto astBuilder = getLinkage()->getASTBuilder(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + Expr* expr = nullptr; + + if (_isSimpleName(name)) + { + auto varExpr = astBuilder->create<VarExpr>(); + varExpr->scope = scope; + varExpr->name = getLinkage()->getNamePool()->getName(name); + expr = varExpr; + } + else + { + expr = linkage->parseTermString(name, scope); + } + SemanticsContext context(linkage->getSemanticsForReflection()); + context = context.allowStaticReferenceToNonStaticMember().withSink(sink); + + SemanticsVisitor visitor(context); + + GenericAppExpr* genericOuterExpr = nullptr; + if (as<GenericAppExpr>(expr)) + { + // Unwrap the generic application, and re-wrap it around the static-member expr + genericOuterExpr = as<GenericAppExpr>(expr); + expr = genericOuterExpr->functionExpr; + } + + if (!as<VarExpr>(expr)) + return nullptr; + + auto rs = astBuilder->create<StaticMemberExpr>(); + auto typeExpr = astBuilder->create<SharedTypeExpr>(); + auto typetype = astBuilder->getOrCreate<TypeType>(type); + typeExpr->type = typetype; + rs->baseExpression = typeExpr; + rs->name = as<VarExpr>(expr)->name; + + expr = rs; + + // If we have a generic-app expression, re-wrap the static-member expr + if (genericOuterExpr) + { + genericOuterExpr->functionExpr = expr; + expr = genericOuterExpr; + } + + auto checkedTerm = visitor.CheckTerm(expr); + + // Check if checkedTerm is overloaded functions and avoid resolving if so + // to preserve all function overloads with different signatures + Expr* resolvedTerm = checkedTerm; + if (auto overloadedExpr = as<OverloadedExpr>(checkedTerm)) + { + // Check if all candidates are function references + bool allAreFunctions = true; + for (auto item : overloadedExpr->lookupResult2.items) + { + if (!as<FunctionDeclBase>(item.declRef.getDecl())) + { + allAreFunctions = false; + break; + } + } + + // If not all are functions, resolve the overload as usual + if (!allAreFunctions) + { + resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); + } + } + else + { + // Not overloaded, resolve as usual + resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); + } + + if (auto overloadedExpr = as<OverloadedExpr>(resolvedTerm)) + { + return overloadedExpr; + } + if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm)) + { + return declRefExpr; + } + + return nullptr; +} + +bool ComponentType::isSubType(Type* subType, Type* superType) +{ + SemanticsContext context(getLinkage()->getSemanticsForReflection()); + SemanticsVisitor visitor(context); + + return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr); +} + +static void collectExportedConstantInContainer( + Dictionary<String, IntVal*>& dict, + ASTBuilder* builder, + ContainerDecl* containerDecl) +{ + for (auto varMember : containerDecl->getDirectMemberDeclsOfType<VarDeclBase>()) + { + if (!varMember->val) + continue; + bool isExported = false; + bool isConst = false; + bool isExtern = false; + for (auto modifier : varMember->modifiers) + { + if (as<HLSLExportModifier>(modifier)) + isExported = true; + if (as<ExternAttribute>(modifier) || as<ExternModifier>(modifier)) + { + isExtern = true; + isExported = true; + } + if (as<ConstModifier>(modifier)) + isConst = true; + } + if (isExported && isConst) + { + auto mangledName = getMangledName(builder, varMember); + if (isExtern && dict.containsKey(mangledName)) + continue; + dict[mangledName] = varMember->val; + } + } + + for (auto member : containerDecl->getDirectMemberDecls()) + { + if (as<NamespaceDecl>(member) || as<FileDecl>(member)) + { + collectExportedConstantInContainer(dict, builder, (ContainerDecl*)member); + } + } +} + +Dictionary<String, IntVal*>& ComponentType::getMangledNameToIntValMap() +{ + if (m_mapMangledNameToIntVal) + { + return *m_mapMangledNameToIntVal; + } + m_mapMangledNameToIntVal = std::make_unique<Dictionary<String, IntVal*>>(); + auto astBuilder = getLinkage()->getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + for (; scope; scope = scope->nextSibling) + { + if (scope->containerDecl) + collectExportedConstantInContainer( + *m_mapMangledNameToIntVal, + astBuilder, + scope->containerDecl); + } + return *m_mapMangledNameToIntVal; +} + +ConstantIntVal* ComponentType::tryFoldIntVal(IntVal* intVal) +{ + auto astBuilder = getLinkage()->getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + return as<ConstantIntVal>(intVal->linkTimeResolve(getMangledNameToIntValMap())); +} + +TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) +{ + RefPtr<TargetProgram> targetProgram; + if (!m_targetPrograms.tryGetValue(target, targetProgram)) + { + targetProgram = new TargetProgram(this, target); + m_targetPrograms[target] = targetProgram; + } + return targetProgram; +} + +} // namespace Slang diff --git a/source/slang/slang-linkable.h b/source/slang/slang-linkable.h new file mode 100644 index 000000000..e900fd275 --- /dev/null +++ b/source/slang/slang-linkable.h @@ -0,0 +1,430 @@ +// slang-linkable.h +#pragma once + +// +// This file defines the `ComponentType` class, which +// provides the root of the hierarchy for classes +// that represent units of linkable code. +// +// The most obvious case of linkable code is a single +// `Module` produced by invoking the Slang front-end +// on source code (or by loading a previously compiled +// `.slang-module` file). +// + +#include "../compiler-core/slang-artifact.h" +#include "slang-ast-base.h" +#include "slang-compiler-fwd.h" +#include "slang-compiler-options.h" + +#include <slang-com-helper.h> +#include <slang.h> + +namespace Slang +{ +class Linkage; + +class EntryPoint; + +class ComponentType; +class ComponentTypeVisitor; + +/// Information collected about global or entry-point shader parameters +struct ShaderParamInfo +{ + DeclRef<VarDeclBase> paramDeclRef; + Int firstSpecializationParamIndex = 0; + Int specializationParamCount = 0; +}; + +/// Tracks an ordered list of modules that something depends on. +/// TODO: Shader caching currently relies on this being in well defined order. +struct ModuleDependencyList +{ +public: + /// Get the list of modules that are depended on. + List<Module*> const& getModuleList() { return m_moduleList; } + + /// Add a module and everything it depends on to the list. + void addDependency(Module* module); + + /// Add a module to the list, but not the modules it depends on. + void addLeafDependency(Module* module); + +private: + void _addDependency(Module* module); + + List<Module*> m_moduleList; + HashSet<Module*> m_moduleSet; +}; + +/// Tracks an unordered list of source files that something depends on +/// TODO: Shader caching currently relies on this being in well defined order. +struct FileDependencyList +{ +public: + /// Get the list of files that are depended on. + List<SourceFile*> const& getFileList() { return m_fileList; } + + /// Add a file to the list, if it is not already present + void addDependency(SourceFile* sourceFile); + + /// Add all of the paths that `module` depends on to the list + void addDependency(Module* module); + + void clear() + { + m_fileList.clear(); + m_fileSet.clear(); + } + +private: + // TODO: We are using a `HashSet` here to deduplicate + // the paths so that we don't return the same path + // multiple times from `getFilePathList`, but because + // order isn't important, we could potentially do better + // in terms of memory (at some cost in performance) by + // just sorting the `m_fileList` every once in + // a while and then deduplicating. + + List<SourceFile*> m_fileList; + HashSet<SourceFile*> m_fileSet; +}; + +/// Base class for "component types" that represent the pieces a final +/// shader program gets linked together from. +/// +class ComponentType : public RefObject, + public slang::IComponentType, + public slang::IComponentType2, + public slang::IModulePrecompileService_Experimental +{ +public: + // + // ISlangUnknown interface + // + + SLANG_REF_OBJECT_IUNKNOWN_ALL; + ISlangUnknown* getInterface(Guid const& guid); + + // + // slang::IComponentType interface + // + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics); + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + /// ComponentType is the only class inheriting from IComponentType that provides a + /// meaningful implementation for this function. All others should forward these and + /// implement `buildHash`. + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override; + + // + // slang::IComponentType2 interface + // + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + // + // slang::IModulePrecompileService interface + // + SLANG_NO_THROW SlangResult SLANG_MCALL + precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( + SlangInt dependencyIndex, + slang::IModule** outModule, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + CompilerOptionSet& getOptionSet() { return m_optionSet; } + + /// Get the linkage (aka "session" in the public API) for this component type. + Linkage* getLinkage() { return m_linkage; } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Update the hash builder with the dependencies for this component type. + virtual void buildHash(DigestBuilder<SHA1>& builder) = 0; + + /// Get the number of entry points linked into this component type. + virtual Index getEntryPointCount() = 0; + + /// Get one of the entry points linked into this component type. + virtual RefPtr<EntryPoint> getEntryPoint(Index index) = 0; + + /// Get the mangled name of one of the entry points linked into this component type. + virtual String getEntryPointMangledName(Index index) = 0; + + /// Get the name override of one of the entry points linked into this component type. + virtual String getEntryPointNameOverride(Index index) = 0; + + /// Get the number of global shader parameters linked into this component type. + virtual Index getShaderParamCount() = 0; + + /// Get one of the global shader parametesr linked into this component type. + virtual ShaderParamInfo getShaderParam(Index index) = 0; + + /// Get the specialization parameter at `index`. + virtual SpecializationParam const& getSpecializationParam(Index index) = 0; + + /// Get the number of "requirements" that this component type has. + /// + /// A requirement represents another component type that this component + /// needs in order to function correctly. For example, the dependency + /// of one module on another module that it `import`s is represented + /// as a requirement, as is the dependency of an entry point on the + /// module that defines it. + /// + virtual Index getRequirementCount() = 0; + + /// Get the requirement at `index`. + virtual RefPtr<ComponentType> getRequirement(Index index) = 0; + + /// Parse a type from a string, in the context of this component type. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + /// TODO: This function shouldn't be on the base class, since + /// it only really makes sense on `Module`. + /// + Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink); + + Expr* findDeclFromString(String const& name, DiagnosticSink* sink); + + Expr* findDeclFromStringInType( + Type* type, + String const& name, + LookupMask mask, + DiagnosticSink* sink); + + bool isSubType(Type* subType, Type* superType); + + Dictionary<String, IntVal*>& getMangledNameToIntValMap(); + ConstantIntVal* tryFoldIntVal(IntVal* intVal); + + /// Get a list of modules that this component type depends on. + /// + virtual List<Module*> const& getModuleDependencies() = 0; + + /// Get the full list of source files this component type depends on. + /// + virtual List<SourceFile*> const& getFileDependencies() = 0; + + /// Callback for use with `enumerateIRModules` + typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component + /// type. + void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component + /// type. + template<typename F> + void enumerateIRModules(F const& callback) + { + struct Helper + { + static void helper(IRModule* irModule, void* userData) { (*(F*)userData)(irModule); } + }; + enumerateIRModules(&Helper::helper, (void*)&callback); + } + + /// Callback for use with `enumerateModules` + typedef void (*EnumerateModulesCallback)(Module* module, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component + /// type. + void enumerateModules(EnumerateModulesCallback callback, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component + /// type. + template<typename F> + void enumerateModules(F const& callback) + { + struct Helper + { + static void helper(Module* module, void* userData) { (*(F*)userData)(module); } + }; + enumerateModules(&Helper::helper, (void*)&callback); + } + + /// Side-band information generated when specializing this component type. + /// + /// Difference subclasses of `ComponentType` are expected to create their + /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. + /// Later, whenever we want to use a specialized component type we will + /// also have the `SpecializationInfo` available and will expect it to + /// have the correct (subclass-specific) type. + /// + class SpecializationInfo : public RefObject + { + }; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) = 0; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + RefPtr<SpecializationInfo> _validateSpecializationArgs( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) + { + if (argCount == 0) + return nullptr; + return _validateSpecializationArgsImpl(args, argCount, sink); + } + + /// Specialize this component type given `specializationArgs` + /// + /// Any diagnostics will be reported to `sink`, which can be used + /// to determine if the operation was successful. It is allowed + /// for this operation to have a non-null return even when an + /// error is ecnountered. + /// + RefPtr<ComponentType> specialize( + SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink); + + /// Invoke `visitor` on this component type, using the appropriate dynamic type. + /// + /// This function implements the "visitor pattern" for `ComponentType`. + /// + /// If the `specializationInfo` argument is non-null, it must be specialization + /// information generated for this specific component type by `_validateSpecializationArgs`. + /// In that case, appropriately-typed specialization information will be passed + /// when invoking the `visitor`. + /// + virtual void acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) = 0; + + /// Create a scope suitable for looking up names or parsing specialization arguments. + /// + /// This facility is only needed to support legacy APIs for string-based lookup + /// and parsing via Slang reflection, and is not recommended for future APIs to use. + /// + Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder); + +protected: + ComponentType(Linkage* linkage); + +protected: + Linkage* m_linkage; + + CompilerOptionSet m_optionSet; + + // Cache of target-specific programs for each target. + Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + // + // TODO: Remove this. Type lookup should only be supported on `Module`s. + // + Dictionary<String, Type*> m_types; + + // Any decls looked up dynamically using `findDeclFromString`. + Dictionary<String, Expr*> m_decls; + + Scope* m_lookupScope = nullptr; + std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal; + + Dictionary<Int, ComPtr<IArtifact>> m_targetArtifacts; +}; + +} // namespace Slang diff --git a/source/slang/slang-module.cpp b/source/slang/slang-module.cpp new file mode 100644 index 000000000..06967de7d --- /dev/null +++ b/source/slang/slang-module.cpp @@ -0,0 +1,420 @@ +// slang-module.cpp +#include "slang-module.h" + +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-mangle.h" +#include "slang-serialize-container.h" + +namespace Slang +{ + +// +// Module +// + +Module::Module(Linkage* linkage, ASTBuilder* astBuilder) + : ComponentType(linkage), m_mangledExportPool(StringSlicePool::Style::Empty) +{ + if (astBuilder) + { + m_astBuilder = astBuilder; + } + else + { + m_astBuilder = linkage->getASTBuilder(); + } + getOptionSet() = linkage->m_optionSet; + addModuleDependency(this); +} + +ISlangUnknown* Module::getInterface(const Guid& guid) +{ + if (guid == IModule::getTypeGuid()) + return asExternal(this); + if (guid == IModulePrecompileService_Experimental::getTypeGuid()) + return static_cast<slang::IModulePrecompileService_Experimental*>(this); + return Super::getInterface(guid); +} + +void Module::buildHash(DigestBuilder<SHA1>& builder) +{ + builder.append(computeDigest()); +} + +slang::DeclReflection* Module::getModuleReflection() +{ + return (slang::DeclReflection*)m_moduleDecl; +} + +SHA1::Digest Module::computeDigest() +{ + if (m_digest == SHA1::Digest()) + { + DigestBuilder<SHA1> digestBuilder; + auto version = String(getBuildTagString()); + digestBuilder.append(version); + getOptionSet().buildHash(digestBuilder); + + auto fileDependencies = getFileDependencies(); + + for (auto file : fileDependencies) + { + digestBuilder.append(file->getDigest()); + } + m_digest = digestBuilder.finalize(); + } + return m_digest; +} + +void Module::addModuleDependency(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_fileDependencyList.addDependency(module); +} + +void Module::addFileDependency(SourceFile* sourceFile) +{ + m_fileDependencyList.addDependency(sourceFile); +} + +void Module::setModuleDecl(ModuleDecl* moduleDecl) +{ + m_moduleDecl = moduleDecl; + moduleDecl->module = this; +} + +void Module::setName(String name) +{ + m_name = getLinkage()->getNamePool()->getName(name); +} + + +RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) +{ + for (auto entryPoint : m_entryPoints) + { + if (entryPoint->getName()->text.getUnownedSlice() == name) + return entryPoint; + } + + return nullptr; +} + +RefPtr<EntryPoint> Module::findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics) +{ + // If there is already an entrypoint marked with the [shader] attribute, + // we should just return that. + // + if (auto existingEntryPoint = findEntryPointByName(name)) + return existingEntryPoint; + + SLANG_AST_BUILDER_RAII(m_astBuilder); + + // If the function hasn't been marked as [shader], then it won't be discovered + // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` + // function. To do that we need to setup a FrontEndCompileRequest and a + // FrontEndEntryPointRequest. + // + DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer()); + FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink); + RefPtr<TranslationUnitRequest> tuRequest = new TranslationUnitRequest(&frontEndRequest); + tuRequest->module = this; + tuRequest->moduleName = m_name; + frontEndRequest.translationUnits.add(tuRequest); + FrontEndEntryPointRequest entryPointRequest( + &frontEndRequest, + 0, + getLinkage()->getNamePool()->getName(name), + Profile((Stage)stage)); + auto result = findAndValidateEntryPoint(&entryPointRequest); + if (outDiagnostics) + { + sink.getBlobIfNeeded(outDiagnostics); + } + return result; +} + +void Module::_addEntryPoint(EntryPoint* entryPoint) +{ + m_entryPoints.add(entryPoint); +} + +static bool _canExportDeclSymbol(ASTNodeType type) +{ + switch (type) + { + case ASTNodeType::EmptyDecl: + { + return false; + } + default: + break; + } + + return true; +} + +static bool _canRecurseExportSymbol(Decl* decl) +{ + if (as<FunctionDeclBase>(decl) || as<ScopeDecl>(decl)) + { + return false; + } + return true; +} + +void Module::_processFindDeclsExportSymbolsRec(Decl* decl) +{ + if (_canExportDeclSymbol(decl->astNodeType)) + { + // It's a reference to a declaration in another module, so first get the symbol name. + String mangledName = getMangledName(getCurrentASTBuilder(), decl); + + Index index = Index(m_mangledExportPool.add(mangledName)); + + // TODO(JS): It appears that more than one entity might have the same mangled name. + // So for now we ignore and just take the first one. + if (index == m_mangledExportSymbols.getCount()) + { + m_mangledExportSymbols.add(decl); + } + } + + if (!_canRecurseExportSymbol(decl)) + { + // We don't need to recurse any further into this + return; + } + + // If it's a container process it's children + if (auto containerDecl = as<ContainerDecl>(decl)) + { + for (auto child : containerDecl->getDirectMemberDecls()) + { + _processFindDeclsExportSymbolsRec(child); + } + } + + // GenericDecl is also a container, so do subsequent test + if (auto genericDecl = as<GenericDecl>(decl)) + { + _processFindDeclsExportSymbolsRec(genericDecl->inner); + } +} + +Decl* Module::findExportedDeclByMangledName(const UnownedStringSlice& mangledName) +{ + // If this module is a serialized module that is being + // deserialized on-demand, then we want to use the + // mangled name mapping that was baked into the serialized + // data, rather than attempt to enumerate all of the declarations + // in the module (as would be done if we proceeded to call + // `ensureExportLookupAcceleratorBuilt()`). + // + if (this->m_moduleDecl->isUsingOnDemandDeserializationForExports()) + { + return m_moduleDecl->_findSerializedDeclByMangledExportName(mangledName); + } + + ensureExportLookupAcceleratorBuilt(); + + const Index index = m_mangledExportPool.findIndex(mangledName); + return (index >= 0) ? m_mangledExportSymbols[index] : nullptr; +} + +void Module::ensureExportLookupAcceleratorBuilt() +{ + // Will be non zero if has been previously attempted + if (m_mangledExportSymbols.getCount() == 0) + { + // Build up the exported mangled name list + _processFindDeclsExportSymbolsRec(getModuleDecl()); + + // If nothing found, mark that we have tried looking by making + // m_mangledExportSymbols.getCount() != 0 + if (m_mangledExportSymbols.getCount() == 0) + { + m_mangledExportSymbols.add(nullptr); + } + } +} + +Count Module::getExportedDeclCount() +{ + ensureExportLookupAcceleratorBuilt(); + + return m_mangledExportPool.getSlicesCount(); +} + +Decl* Module::getExportedDecl(Index index) +{ + ensureExportLookupAcceleratorBuilt(); + return m_mangledExportSymbols[index]; +} + +UnownedStringSlice Module::getExportedDeclMangledName(Index index) +{ + ensureExportLookupAcceleratorBuilt(); + return m_mangledExportPool.getSlices()[index]; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Module::serialize(ISlangBlob** outSerializedBlob) +{ + SerialContainerUtil::WriteOptions writeOptions; + OwnedMemoryStream memoryStream(FileAccess::Write); + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, writeOptions, &memoryStream)); + *outSerializedBlob = RawBlob::create( + memoryStream.getContents().getBuffer(), + (size_t)memoryStream.getContents().getCount()) + .detach(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Module::writeToFile(char const* fileName) +{ + SerialContainerUtil::WriteOptions writeOptions; + FileStream fileStream; + SLANG_RETURN_ON_FAIL(fileStream.init(fileName, FileMode::Create)); + return SerialContainerUtil::write(this, writeOptions, &fileStream); +} + +SLANG_NO_THROW const char* SLANG_MCALL Module::getName() +{ + if (m_name) + return m_name->text.getBuffer(); + return nullptr; +} + +SLANG_NO_THROW const char* SLANG_MCALL Module::getFilePath() +{ + if (m_pathInfo.hasFoundPath()) + return m_pathInfo.foundPath.getBuffer(); + return nullptr; +} + +SLANG_NO_THROW const char* SLANG_MCALL Module::getUniqueIdentity() +{ + if (m_pathInfo.hasUniqueIdentity()) + return m_pathInfo.getMostUniqueIdentity().getBuffer(); + return nullptr; +} + +SLANG_NO_THROW SlangInt32 SLANG_MCALL Module::getDependencyFileCount() +{ + return (SlangInt32)getFileDependencies().getCount(); +} + +SLANG_NO_THROW char const* SLANG_MCALL Module::getDependencyFilePath(SlangInt32 index) +{ + SourceFile* sourceFile = getFileDependencies()[index]; + return sourceFile->getPathInfo().hasFoundPath() + ? sourceFile->getPathInfo().getMostUniqueIdentity().getBuffer() + : nullptr; +} + +void Module::_discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets) +{ + if (m_entryPoints.getCount() > 0) + return; + _discoverEntryPointsImpl(m_moduleDecl, sink, targets); +} +void Module::_discoverEntryPointsImpl( + ContainerDecl* containerDecl, + DiagnosticSink* sink, + const List<RefPtr<TargetRequest>>& targets) +{ + for (auto globalDecl : containerDecl->getDirectMemberDecls()) + { + auto maybeFuncDecl = globalDecl; + if (auto genericDecl = as<GenericDecl>(maybeFuncDecl)) + { + maybeFuncDecl = genericDecl->inner; + } + + if (as<NamespaceDeclBase>(globalDecl) || as<FileDecl>(globalDecl) || + as<StructDecl>(globalDecl)) + { + _discoverEntryPointsImpl(as<ContainerDecl>(globalDecl), sink, targets); + continue; + } + + auto funcDecl = as<FuncDecl>(maybeFuncDecl); + if (!funcDecl) + continue; + + Profile profile; + bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint( + profile, + getLinkage()->m_optionSet, + targets, + funcDecl, + sink); + if (!resolvedStageOfProfileWithEntryPoint) + { + // If there isn't a [shader] attribute, look for a [numthreads] attribute + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // + + bool allTargetsCUDARelated = true; + for (auto target : targets) + { + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } + } + + if (allTargetsCUDARelated && targets.getCount() > 0) + continue; + + bool canDetermineStage = false; + for (auto modifier : funcDecl->modifiers) + { + if (as<NumThreadsAttribute>(modifier)) + { + if (funcDecl->findModifier<OutputTopologyAttribute>()) + profile.setStage(Stage::Mesh); + else + profile.setStage(Stage::Compute); + canDetermineStage = true; + break; + } + else if (as<PatchConstantFuncAttribute>(modifier)) + { + profile.setStage(Stage::Hull); + canDetermineStage = true; + break; + } + } + if (!canDetermineStage) + continue; + } + + RefPtr<EntryPoint> entryPoint = + EntryPoint::create(getLinkage(), makeDeclRef(funcDecl), profile); + + validateEntryPoint(entryPoint, sink); + + // Note: in the case that the user didn't explicitly + // specify entry points and we are instead compiling + // a shader "library," then we do not want to automatically + // combine the entry points into groups in the generated + // `Program`, since that would be slightly too magical. + // + // Instead, each entry point will end up in a singleton + // group, so that its entry-point parameters lay out + // independent of the others. + // + _addEntryPoint(entryPoint); + } +} + + +} // namespace Slang diff --git a/source/slang/slang-module.h b/source/slang/slang-module.h new file mode 100644 index 000000000..7a71f242d --- /dev/null +++ b/source/slang/slang-module.h @@ -0,0 +1,525 @@ +// slang-module.h +#pragma once + +// +// This file provides the `Module` class, which is +// central to many parts of the Slang compiler codebase. +// + +#include "../core/slang-string-util.h" +#include "slang-ast-builder.h" +#include "slang-entry-point.h" +#include "slang-linkable.h" + +namespace Slang +{ + +/// A module of code that has been compiled through the front-end +/// +/// A module comprises all the code from one translation unit (which +/// may span multiple Slang source files), and provides access +/// to both the AST and IR representations of that code. +/// +/// This class serves multiple important roles in the Slang compiler: +/// +/// * this class implements the `slang::IModule` interface from +/// the public Slang API. +/// +/// * this class is the primary output of front-end compilation, +/// and its data is what gets stored/loaded using the `.slang-module` +/// file format. +/// +/// * The checked AST in a `Module` provides all of the information +/// that the front-end uses when checking code that `import`s +/// that module (e.g., the names and signatures of functions defined +/// in the module). +/// +/// * The checked AST is also used to service queries through the +/// Slang reflection API. +/// +/// * the `Module` class is a subclass of `ComponentType` and thus +/// is a unit of linkable code. One or more modules (and other +/// linkable objects) can be combined to form a linked program. +/// +/// * The Slang IR in a `Module` provides all of the information that +/// the back-end uses when generating code for a program/binary +/// that links this module (or any of its entry points). +/// +class Module : public ComponentType, public slang::IModule +{ + typedef ComponentType Super; + +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + ISlangUnknown* getInterface(const Guid& guid); + + + // Forward `IComponentType` methods + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } + + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + findEntryPointByName(char const* name, slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE + { + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + SLANG_AST_BUILDER_RAII(m_astBuilder); + ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name))); + if ((!entryPoint)) + return SLANG_FAIL; + + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) override + { + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + ComPtr<slang::IEntryPoint> entryPoint( + findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); + if ((!entryPoint)) + return SLANG_FAIL; + + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + + virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override + { + return (SlangInt32)m_entryPoints.getCount(); + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override + { + if (index < 0 || index >= m_entryPoints.getCount()) + return SLANG_E_INVALID_ARG; + + if (outEntryPoint == nullptr) + { + return SLANG_E_INVALID_ARG; + } + + ComPtr<slang::IEntryPoint> entryPoint(m_entryPoints[index].Ptr()); + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } + // + + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCompileResult( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCompileResult( + entryPointIndex, + targetIndex, + outCompileResult, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCompileResult( + SlangInt targetIndex, + slang::ICompileResult** outCompileResult, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCompileResult(targetIndex, outCompileResult, outDiagnostics); + } + + /// Get a serialized representation of the checked module. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + serialize(ISlangBlob** outSerializedBlob) override; + + /// Write the serialized representation of this module to a file. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override; + + /// Get the name of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override; + + /// Get the path of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override; + + /// Get the unique identity of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override; + + /// Get the number of dependency files that this module depends on. + /// This includes both the explicit source files, as well as any + /// additional files that were transitively referenced (e.g., via + /// a `#include` directive). + virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override; + + /// Get the path to a file this module depends on. + virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(SlangInt32 index) override; + + + // IModulePrecompileService_Experimental + /// Precompile TU to target language + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) override; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; + + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( + SlangInt dependencyIndex, + slang::IModule** outModule, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE; + + virtual SLANG_NO_THROW slang::DeclReflection* SLANG_MCALL getModuleReflection() SLANG_OVERRIDE; + + void setDigest(SHA1::Digest const& digest) { m_digest = digest; } + SHA1::Digest computeDigest(); + + /// Create a module (initially empty). + Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr); + + /// Get the AST for the module (if it has been parsed) + ModuleDecl* getModuleDecl() { return m_moduleDecl; } + + /// The the IR for the module (if it has been generated) + IRModule* getIRModule() { return m_irModule; } + + /// Get the list of other modules this module depends on + List<Module*> const& getModuleDependencyList() + { + return m_moduleDependencyList.getModuleList(); + } + + /// Get the list of files this module depends on + List<SourceFile*> const& getFileDependencyList() { return m_fileDependencyList.getFileList(); } + + /// Register a module that this module depends on + void addModuleDependency(Module* module); + + /// Register a source file that this module depends on + void addFileDependency(SourceFile* sourceFile); + + void clearFileDependency() { m_fileDependencyList.clear(); } + /// Set the AST for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setModuleDecl(ModuleDecl* moduleDecl); // { m_moduleDecl = moduleDecl; } + + void setName(String name); + void setName(Name* name) { m_name = name; } + Name* getNameObj() { return m_name; } + + void setPathInfo(PathInfo pathInfo) { m_pathInfo = pathInfo; } + + /// Set the IR for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setIRModule(IRModule* irModule) { m_irModule = irModule; } + + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr<EntryPoint> getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return String(); + } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return String(); + } + + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } + + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE + { + return m_specializationParams.getCount(); + } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + return m_specializationParams[index]; + } + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE; + + List<Module*> const& getModuleDependencies() SLANG_OVERRIDE + { + return m_moduleDependencyList.getModuleList(); + } + List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE + { + return m_fileDependencyList.getFileList(); + } + + /// Given a mangled name finds the exported NodeBase associated with this module. + /// If not found returns nullptr. + Decl* findExportedDeclByMangledName(const UnownedStringSlice& mangledName); + + /// Ensure that the any accelerator(s) used for `findExportedDeclByMangledName` + /// have already been built. + /// + void ensureExportLookupAcceleratorBuilt(); + + Count getExportedDeclCount(); + Decl* getExportedDecl(Index index); + UnownedStringSlice getExportedDeclMangledName(Index index); + + /// Get the ASTBuilder + ASTBuilder* getASTBuilder() { return m_astBuilder; } + + /// Collect information on the shader parameters of the module. + /// + /// This method should only be called once, after the core + /// structured of the module (its AST and IR) have been created, + /// and before any of the `ComponentType` APIs are used. + /// + /// TODO: We might eventually consider a non-stateful approach + /// to constructing a `Module`. + /// + void _collectShaderParams(); + + void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets); + void _discoverEntryPointsImpl( + ContainerDecl* containerDecl, + DiagnosticSink* sink, + const List<RefPtr<TargetRequest>>& targets); + + + class ModuleSpecializationInfo : public SpecializationInfo + { + public: + struct GenericArgInfo + { + Decl* paramDecl = nullptr; + Val* argVal = nullptr; + }; + + List<GenericArgInfo> genericArgs; + List<ExpandedSpecializationArg> existentialArgs; + }; + + RefPtr<EntryPoint> findEntryPointByName(UnownedStringSlice const& name); + RefPtr<EntryPoint> findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics); + + List<RefPtr<EntryPoint>>& getEntryPoints() { return m_entryPoints; } + void _addEntryPoint(EntryPoint* entryPoint); + void _processFindDeclsExportSymbolsRec(Decl* decl); + + // Gets the files that has been included into the module. + Dictionary<SourceFile*, FileDecl*>& getIncludedSourceFileMap() + { + return m_mapSourceFileToFileDecl; + } + +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +private: + Name* m_name = nullptr; + PathInfo m_pathInfo; + + // The AST for the module + ModuleDecl* m_moduleDecl = nullptr; + + // The IR for the module + RefPtr<IRModule> m_irModule = nullptr; + + List<ShaderParamInfo> m_shaderParams; + SpecializationParams m_specializationParams; + + List<Module*> m_requirements; + + // A digest that uniquely identifies the contents of the module. + SHA1::Digest m_digest; + + // List of modules this module depends on + ModuleDependencyList m_moduleDependencyList; + + // List of source files this module depends on + FileDependencyList m_fileDependencyList; + + // Entry points that were defined in this module + // + // Note: the entry point defined in the module are *not* + // part of the memory image/layout of the module when + // it is considered as an IComponentType. This can be + // a bit confusing, but if all the entry points in the + // module were automatically linked into the component + // type, we'd need a way to access just the global + // scope of the module without the entry points, in + // case we wanted to link a single entry point against + // the global scope. The `Module` type provides exactly + // that "module without its entry points" unit of + // granularity for linking. + // + // This list only exists for lookup purposes, so that + // the user can find an existing entry-point function + // that was defined as part of the module. + // + List<RefPtr<EntryPoint>> m_entryPoints; + + // The builder that owns all of the AST nodes from parsing the source of + // this module. + RefPtr<ASTBuilder> m_astBuilder; + + // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices, + // and m_mangledExportSymbols holds the NodeBase* values for each index. + StringSlicePool m_mangledExportPool; + List<Decl*> m_mangledExportSymbols; + + // Source files that have been pulled into the module with `__include`. + Dictionary<SourceFile*, FileDecl*> m_mapSourceFileToFileDecl; + +public: + SLANG_NO_THROW SlangResult SLANG_MCALL disassemble(slang::IBlob** outDisassembledBlob) override + { + if (!outDisassembledBlob) + return SLANG_E_INVALID_ARG; + String disassembly; + this->getIRModule()->getModuleInst()->dump(disassembly); + auto blob = StringUtil::createStringBlob(disassembly); + *outDisassembledBlob = blob.detach(); + return SLANG_OK; + } +}; + +} // namespace Slang diff --git a/source/slang/slang-pass-through.cpp b/source/slang/slang-pass-through.cpp new file mode 100644 index 000000000..268cbc98f --- /dev/null +++ b/source/slang/slang-pass-through.cpp @@ -0,0 +1,275 @@ +// slang-pass-through.cpp +#include "slang-pass-through.h" + +#include "../core/slang-type-text-util.h" +#include "compiler-core/slang-slice-allocator.h" +#include "slang-compiler.h" + +namespace Slang +{ + +void printDiagnosticArg(StringBuilder& sb, PassThroughMode val) +{ + sb << TypeTextUtil::getPassThroughName(SlangPassThrough(val)); +} + +SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler) +{ + switch (compiler) + { + case PassThroughMode::None: + { + return SourceLanguage::Unknown; + } + case PassThroughMode::Fxc: + case PassThroughMode::Dxc: + { + return SourceLanguage::HLSL; + } + case PassThroughMode::Glslang: + { + return SourceLanguage::GLSL; + } + case PassThroughMode::LLVM: + case PassThroughMode::Clang: + case PassThroughMode::VisualStudio: + case PassThroughMode::Gcc: + case PassThroughMode::GenericCCpp: + { + // These could ingest C, but we only have this function to work out a + // 'default' language to ingest. + return SourceLanguage::CPP; + } + case PassThroughMode::NVRTC: + { + return SourceLanguage::CUDA; + } + case PassThroughMode::Tint: + { + return SourceLanguage::WGSL; + } + case PassThroughMode::SpirvDis: + { + return SourceLanguage::SPIRV; + } + case PassThroughMode::MetalC: + { + return SourceLanguage::Metal; + } + default: + break; + } + SLANG_ASSERT(!"Unknown compiler"); + return SourceLanguage::Unknown; +} + +PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target) +{ + switch (target) + { + // Don't *require* a downstream compiler for source output + case CodeGenTarget::GLSL: + case CodeGenTarget::HLSL: + case CodeGenTarget::CUDASource: + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: + case CodeGenTarget::CSource: + case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: + { + return PassThroughMode::None; + } + case CodeGenTarget::None: + { + return PassThroughMode::None; + } + case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::SPIRV: + { + return PassThroughMode::SpirvDis; + } + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + { + return PassThroughMode::Fxc; + } + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + { + return PassThroughMode::Dxc; + } + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: + { + return PassThroughMode::MetalC; + } + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostSharedLibrary: + { + // We need some C/C++ compiler + return PassThroughMode::GenericCCpp; + } + case CodeGenTarget::PTX: + { + return PassThroughMode::NVRTC; + } + case CodeGenTarget::WGSLSPIRV: + { + return PassThroughMode::Tint; + } + default: + break; + } + + SLANG_ASSERT(!"Unhandled target"); + return PassThroughMode::None; +} + + +void reportExternalCompileError( + const char* compilerName, + Severity severity, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink) +{ + StringBuilder builder; + if (compilerName) + { + builder << compilerName << ": "; + } + + if (SLANG_FAILED(res) && res != SLANG_FAIL) + { + { + char tmp[17]; + sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); + builder << "Result(" << tmp << ") "; + } + + PlatformUtil::appendResult(res, builder); + } + + if (diagnostic.getLength() > 0) + { + builder.append(diagnostic); + if (!diagnostic.endsWith("\n")) + { + builder.append("\n"); + } + } + + sink->diagnoseRaw(severity, builder.getUnownedSlice()); +} + +void reportExternalCompileError( + const char* compilerName, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink) +{ + // TODO(tfoley): need a better policy for how we translate diagnostics + // back into the Slang world (although we should always try to generate + // HLSL that doesn't produce any diagnostics...) + reportExternalCompileError( + compilerName, + SLANG_FAILED(res) ? Severity::Error : Severity::Warning, + res, + diagnostic, + sink); +} + +static Severity _getDiagnosticSeverity(ArtifactDiagnostic::Severity severity) +{ + switch (severity) + { + case ArtifactDiagnostic::Severity::Warning: + return Severity::Warning; + case ArtifactDiagnostic::Severity::Info: + return Severity::Note; + default: + return Severity::Error; + } +} + +SlangResult passthroughDownstreamDiagnostics( + DiagnosticSink* sink, + IDownstreamCompiler* compiler, + IArtifact* artifact) +{ + auto diagnostics = findAssociatedRepresentation<IArtifactDiagnostics>(artifact); + + if (!diagnostics) + return SLANG_OK; + + if (diagnostics->getCount()) + { + StringBuilder compilerText; + DownstreamCompilerUtil::appendAsText(compiler->getDesc(), compilerText); + + StringBuilder builder; + + auto const diagnosticCount = diagnostics->getCount(); + for (Index i = 0; i < diagnosticCount; ++i) + { + const auto& diagnostic = *diagnostics->getAt(i); + + builder.clear(); + + const Severity severity = _getDiagnosticSeverity(diagnostic.severity); + + if (diagnostic.filePath.count == 0 && diagnostic.location.line == 0 && + severity == Severity::Note) + { + // If theres no filePath line number and it's info, output severity and text alone + builder << getSeverityName(severity) << " : "; + } + else + { + if (diagnostic.filePath.count) + { + builder << asStringSlice(diagnostic.filePath); + } + + if (diagnostic.location.line) + { + builder << "(" << diagnostic.location.line << ")"; + } + + builder << ": "; + + if (diagnostic.stage == ArtifactDiagnostic::Stage::Link) + { + builder << "link "; + } + + builder << getSeverityName(severity); + builder << " " << asStringSlice(diagnostic.code) << ": "; + } + + builder << asStringSlice(diagnostic.text); + reportExternalCompileError( + compilerText.getBuffer(), + severity, + SLANG_OK, + builder.getUnownedSlice(), + sink); + } + } + + // If any errors are emitted, then we are done + if (diagnostics->hasOfAtLeastSeverity(ArtifactDiagnostic::Severity::Error)) + { + return SLANG_FAIL; + } + + return SLANG_OK; +} + + +} // namespace Slang diff --git a/source/slang/slang-pass-through.h b/source/slang/slang-pass-through.h new file mode 100644 index 000000000..8cb350572 --- /dev/null +++ b/source/slang/slang-pass-through.h @@ -0,0 +1,91 @@ +// slang-pass-through.h +#pragma once + +// +// This file gathers together declarations for utility code +// related to the concept of "pass-through" compilation. +// +// Note that in the Slang codebase there is an unfortunate +// conflation of terminology (and a resulting conflation +// in a lot of the implementation logic) between the cases +// of true *pass-through* compilation and the logic related +// to *downstream* compilers: +// +// * While the Slang compiler is architected to support direct +// generation of binary output code (e.g., there is support +// for emitting SPIR-V directly from Slang IR), many targets +// are supported by first generating intermediate source code +// from Slang IR and then invoking a *downstream* compiler on +// that intermediate source code to produce output binaries. +// This is an important kind of compilation flow that Slang +// was designed to support. +// +// * In contrast, true *pass-through* compilation is a legacy +// feature of `slangc` that exists almost entirely to support +// some of the earliest test cases that were written for the +// compiler. In true pass-through mode, `slangc` skips invoking +// large parts of the Slang compiler (the entire front-end, +// along with all of the back-end up to the point that intermediate +// source code would be generated) and then uses the *input* +// source as if it was the *intermediate* code for the "last mile" +// of code generation (invoking a downstream compiler). This +// feature can and *should* be deprecate and removed, because +// the complexity it creates for the rest of the compiler is +// no longer worth it. +// +// This file may contain a mix of declarations used for one or +// both of the above purposes, just because the terminology used +// in the codebase isn't always precise or clear. Over time the +// declarations here can and should be more clearly partitioned +// so that we can distinguish the essential (downstream compilation) +// parts, and the parts that should eventually get removed (true +// pass-through compilation). +// + +#include "../core/slang-string.h" +#include "slang-target.h" + +namespace Slang +{ + +enum class PassThroughMode : SlangPassThroughIntegral +{ + None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler + Fxc = SLANG_PASS_THROUGH_FXC, ///< pass through HLSL to `D3DCompile` API + Dxc = SLANG_PASS_THROUGH_DXC, ///< pass through HLSL to `IDxcCompiler` API + Glslang = SLANG_PASS_THROUGH_GLSLANG, ///< pass through GLSL to `glslang` library + SpirvDis = SLANG_PASS_THROUGH_SPIRV_DIS, ///< pass through spirv-dis + Clang = SLANG_PASS_THROUGH_CLANG, ///< Pass through clang compiler + VisualStudio = SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio compiler + Gcc = SLANG_PASS_THROUGH_GCC, ///< Gcc compiler + GenericCCpp = SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C/C++ compiler + NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler + LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' + SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt + MetalC = SLANG_PASS_THROUGH_METAL, + Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API + SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link + CountOf = SLANG_PASS_THROUGH_COUNT_OF, +}; +void printDiagnosticArg(StringBuilder& sb, PassThroughMode val); + +/// Given a target returns the required downstream compiler +PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); + +/// Given a target returns a downstream compiler the prelude should be taken from. +SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); + +/* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic +sink. +@param compilerName The name of the compiler the error came for (or nullptr if not known) +@param res Result associated with the error. The error code will be reported. (Can take HRESULT - +and will expand to string if known) +@param diagnostic The diagnostic string associated with the compile failure +@param sink The diagnostic sink to report to */ +void reportExternalCompileError( + const char* compilerName, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-profile.cpp b/source/slang/slang-profile.cpp index 7e12b4f4e..f7f6b7bdd 100644 --- a/source/slang/slang-profile.cpp +++ b/source/slang/slang-profile.cpp @@ -4,6 +4,190 @@ namespace Slang { +Profile Profile::lookUp(UnownedStringSlice const& name) +{ +#define PROFILE(TAG, NAME, STAGE, VERSION) \ + if (name == UnownedTerminatedStringSlice(#NAME)) \ + return Profile::TAG; +#define PROFILE_ALIAS(TAG, DEF, NAME) \ + if (name == UnownedTerminatedStringSlice(#NAME)) \ + return Profile::TAG; +#include "slang-profile-defs.h" + + return Profile::Unknown; +} + +Profile Profile::lookUp(char const* name) +{ + return lookUp(UnownedTerminatedStringSlice(name)); +} + +CapabilitySet Profile::getCapabilityName() +{ + List<CapabilityName> result; + switch (getVersion()) + { +#define PROFILE_VERSION(TAG, NAME) \ + case ProfileVersion::TAG: \ + result.add(CapabilityName::TAG); \ + break; +#include "slang-profile-defs.h" + default: + break; + } + switch (getStage()) + { +#define PROFILE_STAGE(TAG, NAME, VAL) \ + case Stage::TAG: \ + result.add(CapabilityName::NAME); \ + break; +#include "slang-profile-defs.h" + default: + break; + } + + CapabilitySet resultSet = CapabilitySet(result); + for (auto i : this->additionalCapabilities) + resultSet.join(i); + return resultSet; +} + +char const* Profile::getName() +{ + switch (raw) + { + default: + return "unknown"; + +#define PROFILE(TAG, NAME, STAGE, VERSION) \ + case Profile::TAG: \ + return #NAME; +#define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ +#include "slang-profile-defs.h" + } +} + +static const StageInfo kStages[] = { +#define PROFILE_STAGE(ID, NAME, ENUM) {#NAME, Stage::ID}, + +#define PROFILE_STAGE_ALIAS(ID, NAME, VAL) {#NAME, Stage::ID}, + +#include "slang-profile-defs.h" +}; + +ConstArrayView<StageInfo> getStageInfos() +{ + return makeConstArrayView(kStages); +} + +Stage findStageByName(String const& name) +{ + for (auto entry : kStages) + { + if (name == entry.name) + { + return entry.stage; + } + } + + return Stage::Unknown; +} + +UnownedStringSlice getStageText(Stage stage) +{ + for (auto entry : kStages) + { + if (stage == entry.stage) + { + return UnownedStringSlice(entry.name); + } + } + return UnownedStringSlice(); +} + +Stage getStageFromAtom(CapabilityAtom atom) +{ + switch (atom) + { + case CapabilityAtom::vertex: + return Stage::Vertex; + case CapabilityAtom::hull: + return Stage::Hull; + case CapabilityAtom::domain: + return Stage::Domain; + case CapabilityAtom::geometry: + return Stage::Geometry; + case CapabilityAtom::fragment: + return Stage::Fragment; + case CapabilityAtom::compute: + return Stage::Compute; + case CapabilityAtom::_mesh: + return Stage::Mesh; + case CapabilityAtom::_amplification: + return Stage::Amplification; + case CapabilityAtom::_anyhit: + return Stage::AnyHit; + case CapabilityAtom::_closesthit: + return Stage::ClosestHit; + case CapabilityAtom::_intersection: + return Stage::Intersection; + case CapabilityAtom::_raygen: + return Stage::RayGeneration; + case CapabilityAtom::_miss: + return Stage::Miss; + case CapabilityAtom::_callable: + return Stage::Callable; + case CapabilityAtom::dispatch: + return Stage::Dispatch; + default: + SLANG_UNEXPECTED("unknown stage atom"); + UNREACHABLE_RETURN(Stage::Unknown); + } +} + +CapabilityAtom getAtomFromStage(Stage stage) +{ + // Convert Slang::Stage to CapabilityAtom. + // Note that capabilities do not share the same values as Slang::Stage + // and must be explicitly converted. + switch (stage) + { + case Stage::Compute: + return CapabilityAtom::compute; + case Stage::Vertex: + return CapabilityAtom::vertex; + case Stage::Fragment: + return CapabilityAtom::fragment; + case Stage::Geometry: + return CapabilityAtom::geometry; + case Stage::Hull: + return CapabilityAtom::hull; + case Stage::Domain: + return CapabilityAtom::domain; + case Stage::Mesh: + return CapabilityAtom::_mesh; + case Stage::Amplification: + return CapabilityAtom::_amplification; + case Stage::RayGeneration: + return CapabilityAtom::_raygen; + case Stage::AnyHit: + return CapabilityAtom::_anyhit; + case Stage::ClosestHit: + return CapabilityAtom::_closesthit; + case Stage::Miss: + return CapabilityAtom::_miss; + case Stage::Intersection: + return CapabilityAtom::_intersection; + case Stage::Callable: + return CapabilityAtom::_callable; + case Stage::Dispatch: + return CapabilityAtom::dispatch; + default: + SLANG_UNEXPECTED("unknown stage"); + UNREACHABLE_RETURN(CapabilityAtom::Invalid); + } +} + ProfileFamily getProfileFamily(ProfileVersion version) { switch (version) @@ -59,5 +243,93 @@ void printDiagnosticArg(StringBuilder& sb, ProfileVersion val) sb << Profile(val).getName(); } +String getHLSLProfileName(Profile profile) +{ + switch (profile.getFamily()) + { + case ProfileFamily::DX: + // Profile version is a DX one, so stick with it. + break; + + default: + // Profile is a non-DX profile family, so we need to try + // to clobber it with something to get a default. + // + // TODO: This is a huge hack... + profile.setVersion(ProfileVersion::DX_5_1); + break; + } + + char const* stagePrefix = nullptr; + switch (profile.getStage()) + { + // Note: All of the raytracing-related stages require + // compiling for a `lib_*` profile, even when only a + // single entry point is present. + // + // We also go ahead and use this target in any case + // where we don't know the actual stage to compiel for, + // as a fallback option. + // + // TODO: We also want to use this option when compiling + // multiple entry points to a DXIL library. + // + default: + stagePrefix = "lib"; + break; + + // The traditional rasterization pipeline and compute + // shaders all have custom profile names that identify + // both the stage and shader model, which need to be + // used when compiling a single entry point. + // +#define CASE(NAME, PREFIX) \ + case Stage::NAME: \ + stagePrefix = #PREFIX; \ + break + CASE(Vertex, vs); + CASE(Hull, hs); + CASE(Domain, ds); + CASE(Geometry, gs); + CASE(Fragment, ps); + CASE(Compute, cs); + CASE(Amplification, as); + CASE(Mesh, ms); +#undef CASE + } + + char const* versionSuffix = nullptr; + switch (profile.getVersion()) + { +#define CASE(TAG, SUFFIX) \ + case ProfileVersion::TAG: \ + versionSuffix = #SUFFIX; \ + break + CASE(DX_4_0, _4_0); + CASE(DX_4_1, _4_1); + CASE(DX_5_0, _5_0); + CASE(DX_5_1, _5_1); + CASE(DX_6_0, _6_0); + CASE(DX_6_1, _6_1); + CASE(DX_6_2, _6_2); + CASE(DX_6_3, _6_3); + CASE(DX_6_4, _6_4); + CASE(DX_6_5, _6_5); + CASE(DX_6_6, _6_6); + CASE(DX_6_7, _6_7); + CASE(DX_6_8, _6_8); + CASE(DX_6_9, _6_9); +#undef CASE + + default: + return "unknown"; + } + + String result; + result.append(stagePrefix); + result.append(versionSuffix); + return result; +} + } // namespace Slang diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index ca7b8b2ae..96c968aa8 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -130,6 +130,8 @@ UnownedStringSlice getStageText(Stage stage); Stage getStageFromAtom(CapabilityAtom atom); CapabilityAtom getAtomFromStage(Stage stage); +String getHLSLProfileName(Profile profile); + } // namespace Slang #endif diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 325aa3244..261436e38 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -1234,7 +1234,6 @@ void serialize(ASTSerializer const& serializer, ValNodeOperand& value) // #if 0 // FIDDLE TEMPLATE: %for _,T in ipairs(astStructTypes) do -% TRACE(T) /// Fossilized representation of a value of type `$T` struct Fossilized_$T % if T.directSuperClass then diff --git a/source/slang/slang-session.cpp b/source/slang/slang-session.cpp new file mode 100644 index 000000000..da39949da --- /dev/null +++ b/source/slang/slang-session.cpp @@ -0,0 +1,2068 @@ +// slang-session.cpp +#include "slang-session.h" + +#include "compiler-core/slang-artifact-util.h" +#include "slang-check-impl.h" +#include "slang-compiler.h" +#include "slang-lower-to-ir.h" +#include "slang-mangle.h" +#include "slang-options.h" +#include "slang-parser.h" +#include "slang-preprocessor.h" +#include "slang-serialize-ast.h" +#include "slang-serialize-container.h" +#include "slang-serialize-ir.h" + +namespace Slang +{ + +Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage) + : m_session(session) + , m_retainedSession(session) + , m_sourceManager(&m_defaultSourceManager) + , m_astBuilder(astBuilder) + , m_cmdLineContext(new CommandLineContext()) + , m_stringSlicePool(StringSlicePool::Style::Default) +{ + namePool = session->getNamePool(); + + m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); + + setFileSystem(nullptr); + + // Copy of the built in linkages modules + if (builtinLinkage) + { + for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules) + mapNameToLoadedModules.add(nameToMod); + } + + m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr); +} + +SharedSemanticsContext* Linkage::getSemanticsForReflection() +{ + return m_semanticsForReflection.get(); +} + +ISlangUnknown* Linkage::getInterface(const Guid& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid()) + return asExternal(this); + + return nullptr; +} + +Linkage::~Linkage() +{ + // Upstream type checking cache. + if (m_typeCheckingCache) + { + auto globalSession = getSessionImpl(); + std::lock_guard<std::mutex> lock(globalSession->m_typeCheckingCacheMutex); + if (!globalSession->m_typeCheckingCache || + globalSession->getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount() < + getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount()) + { + globalSession->m_typeCheckingCache = m_typeCheckingCache; + getTypeCheckingCache()->version++; + } + destroyTypeCheckingCache(); + } +} + +SearchDirectoryList& Linkage::getSearchDirectories() +{ + auto list = m_optionSet.getArray(CompilerOptionName::Include); + if (list.getCount() != searchDirectoryCache.searchDirectories.getCount()) + { + searchDirectoryCache.searchDirectories.clear(); + for (auto dir : list) + searchDirectoryCache.searchDirectories.add(SearchDirectory(dir.stringValue)); + } + return searchDirectoryCache; +} + +TypeCheckingCache* Linkage::getTypeCheckingCache() +{ + if (!m_typeCheckingCache) + { + m_typeCheckingCache = new TypeCheckingCache(); + } + return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); +} + +void Linkage::destroyTypeCheckingCache() +{ + m_typeCheckingCache = nullptr; +} + +SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL Linkage::getGlobalSession() +{ + return asExternal(getSessionImpl()); +} + +void Linkage::addTarget(slang::TargetDesc const& desc) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto targetIndex = addTarget(CodeGenTarget(desc.format)); + auto target = targets[targetIndex]; + + auto& optionSet = target->getOptionSet(); + optionSet.inheritFrom(m_optionSet); + + optionSet.set(CompilerOptionName::FloatingPointMode, FloatingPointMode(desc.floatingPointMode)); + optionSet.addTargetFlags(desc.flags); + optionSet.setProfile(Profile(desc.profile)); + optionSet.set(CompilerOptionName::LineDirectiveMode, LineDirectiveMode(desc.lineDirectiveMode)); + optionSet.set(CompilerOptionName::GLSLForceScalarLayout, desc.forceGLSLScalarBufferLayout); + + CompilerOptionSet targetOptions; + targetOptions.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); + optionSet.overrideWith(targetOptions); +} + +SLANG_NO_THROW slang::IModule* SLANG_MCALL +Linkage::loadModule(const char* moduleName, slang::IBlob** outDiagnostics) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + if (isInLanguageServer()) + { + sink.setFlags(DiagnosticSink::Flag::HumaneLoc | DiagnosticSink::Flag::LanguageServer); + } + + try + { + auto name = getNamePool()->getName(moduleName); + + auto module = findOrImportModule(name, SourceLoc(), &sink); + sink.getBlobIfNeeded(outDiagnostics); + + return asExternal(module); + } + catch (const AbortCompilationException& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return nullptr; + } + catch (const Exception& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return nullptr; + } + catch (...) + { + outputExceptionDiagnostic(sink, outDiagnostics); + return nullptr; + } +} + +slang::IModule* Linkage::loadModuleFromBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + ModuleBlobType blobType, + slang::IBlob** outDiagnostics) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + if (isInLanguageServer()) + { + sink.setFlags(DiagnosticSink::Flag::HumaneLoc | DiagnosticSink::Flag::LanguageServer); + } + + + try + { + auto getDigestStr = [](auto x) + { + DigestBuilder<SHA1> digestBuilder; + digestBuilder.append(x); + return digestBuilder.finalize().toString(); + }; + + String moduleNameStr = moduleName; + if (!moduleName) + moduleNameStr = getDigestStr(source); + + auto name = getNamePool()->getName(moduleNameStr); + RefPtr<LoadedModule> loadedModule; + if (mapNameToLoadedModules.tryGetValue(name, loadedModule)) + { + return loadedModule; + } + String pathStr = path; + if (pathStr.getLength() == 0) + { + // If path is empty, use a digest from source as path. + pathStr = getDigestStr(source); + } + auto pathInfo = PathInfo::makeFromString(pathStr); + if (File::exists(pathStr)) + { + String cannonicalPath; + if (SLANG_SUCCEEDED(Path::getCanonical(pathStr, cannonicalPath))) + { + pathInfo = PathInfo::makeNormal(pathStr, cannonicalPath); + } + } + RefPtr<Module> module = + loadModuleImpl(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); + sink.getBlobIfNeeded(outDiagnostics); + return asExternal(module.get()); + } + catch (const AbortCompilationException& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return nullptr; + } + catch (const Exception& e) + { + outputExceptionDiagnostic(e, sink, outDiagnostics); + return nullptr; + } + catch (...) + { + outputExceptionDiagnostic(sink, outDiagnostics); + return nullptr; + } +} + +SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics) +{ + return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::Source, outDiagnostics); +} + +SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* source, + slang::IBlob** outDiagnostics) +{ + auto sourceBlob = StringBlob::create(UnownedStringSlice(source)); + return loadModuleFromSource(moduleName, path, sourceBlob.get(), outDiagnostics); +} + +SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromIRBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics) +{ + return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::IR, outDiagnostics); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::loadModuleInfoFromIRBlob( + slang::IBlob* source, + SlangInt& outModuleVersion, + const char*& outModuleCompilerVersion, + const char*& outModuleName) +{ + // We start by reading the content of the file as + // an in-memory RIFF container. + // + auto rootChunk = RIFF::RootChunk::getFromBlob(source); + if (!rootChunk) + { + return SLANG_FAIL; + } + + auto moduleChunk = ModuleChunk::find(rootChunk); + if (!moduleChunk) + { + return SLANG_FAIL; + } + + auto irChunk = moduleChunk->findIR(); + if (!irChunk) + { + return SLANG_FAIL; + } + + RefPtr<IRModule> irModule; + String compilerVersion; + UInt version; + String name; + SLANG_RETURN_ON_FAIL(readSerializedModuleInfo(irChunk, compilerVersion, version, name)); + const auto compilerVersionSlice = m_stringSlicePool.addAndGetSlice(compilerVersion); + const auto nameSlice = m_stringSlicePool.addAndGetSlice(name); + outModuleCompilerVersion = compilerVersionSlice.begin(); + outModuleName = nameSlice.begin(); + outModuleVersion = SlangInt(version); + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + slang::IComponentType** outCompositeComponentType, + ISlangBlob** outDiagnostics) +{ + if (outCompositeComponentType == nullptr) + return SLANG_E_INVALID_ARG; + + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + // Attempting to create a "composite" of just one component type should + // just return the component type itself, to avoid redundant work. + // + if (componentTypeCount == 1) + { + auto componentType = componentTypes[0]; + componentType->addRef(); + *outCompositeComponentType = componentType; + return SLANG_OK; + } + + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + List<RefPtr<ComponentType>> childComponents; + for (Int cc = 0; cc < componentTypeCount; ++cc) + { + childComponents.add(asInternal(componentTypes[cc])); + } + + RefPtr<ComponentType> composite = CompositeComponentType::create(this, childComponents); + + sink.getBlobIfNeeded(outDiagnostics); + + *outCompositeComponentType = asExternal(composite.detach()); + return SLANG_OK; +} + +SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( + slang::TypeReflection* inUnspecializedType, + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto unspecializedType = asInternal(inUnspecializedType); + + List<Type*> typeArgs; + + for (Int ii = 0; ii < specializationArgCount; ++ii) + { + auto& arg = specializationArgs[ii]; + if (arg.kind != slang::SpecializationArg::Kind::Type) + return nullptr; + + typeArgs.add(asInternal(arg.type)); + } + + DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); + auto specializedType = + specializeType(unspecializedType, typeArgs.getCount(), typeArgs.getBuffer(), &sink); + sink.getBlobIfNeeded(outDiagnostics); + + return asExternal(specializedType); +} + +DeclRef<GenericDecl> getGenericParentDeclRef( + ASTBuilder* astBuilder, + SemanticsVisitor* visitor, + DeclRef<Decl> declRef) +{ + // Create substituted parent decl ref. + auto decl = declRef.getDecl(); + + while (decl && !as<GenericDecl>(decl)) + { + decl = decl->parentDecl; + } + + if (!decl) + { + // No generic parent + return DeclRef<GenericDecl>(); + } + + auto genericDecl = as<GenericDecl>(decl); + auto genericDeclRef = + createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)) + .as<GenericDecl>(); + return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef) + .as<GenericDecl>(); +} + +bool Linkage::isSpecialized(DeclRef<Decl> declRef) +{ + // For now, we only support two 'states': fully applied or not at all. + // If we add support for partial specialization, we will need to update this logic. + // + // If it's not specialized, then declRef will be the one with default substitutions. + // + SemanticsVisitor visitor(getSemanticsForReflection()); + + auto decl = declRef.getDecl(); + while (decl && !as<GenericDecl>(decl)) + { + decl = decl->parentDecl; + } + + if (!decl) + return true; // no generics => always specialized + + auto defaultArgs = getDefaultSubstitutionArgs(getASTBuilder(), &visitor, as<GenericDecl>(decl)); + auto currentArgs = + SubstitutionSet(declRef).findGenericAppDeclRef(as<GenericDecl>(decl))->getArgs(); + + if (defaultArgs.getCount() != currentArgs.getCount()) // should really never happen. + return true; + + for (Index i = 0; i < defaultArgs.getCount(); ++i) + { + if (defaultArgs[i] != currentArgs[i]) + return true; + } + + return false; +} + +bool isFuncGeneric(DeclRef<Decl> declRef) +{ + if (auto funcDecl = as<FuncDecl>(declRef.getDecl())) + { + if (funcDecl->parentDecl && as<GenericDecl>(funcDecl->parentDecl)) + { + return true; + } + } + + return false; +} + +DeclRef<Decl> Linkage::specializeWithArgTypes( + Expr* funcExpr, + List<Type*> argTypes, + DiagnosticSink* sink) +{ + SemanticsVisitor visitor(getSemanticsForReflection()); + SemanticsVisitor::ExprLocalScope scope; + visitor = visitor.withSink(sink).withExprLocalScope(&scope); + + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + if (auto declRefFuncExpr = as<DeclRefExpr>(funcExpr)) + { + if (isFuncGeneric(declRefFuncExpr->declRef) && !isSpecialized(declRefFuncExpr->declRef)) + { + if (auto genericDeclRef = getGenericParentDeclRef( + getCurrentASTBuilder(), + &visitor, + declRefFuncExpr->declRef)) + { + auto genericDeclRefExpr = getCurrentASTBuilder()->create<DeclRefExpr>(); + genericDeclRefExpr->declRef = genericDeclRef; + funcExpr = genericDeclRefExpr; + } + } + } + + List<Expr*> argExprs; + for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa) + { + auto argType = argTypes[aa]; + + // Create an 'empty' expr with the given type. Ideally, the expression itself should not + // matter only its checked type. + // + auto argExpr = getCurrentASTBuilder()->create<VarExpr>(); + argExpr->type = argType; + argExpr->type.isLeftValue = true; + argExprs.add(argExpr); + } + + // Construct invoke expr. + auto invokeExpr = getCurrentASTBuilder()->create<InvokeExpr>(); + invokeExpr->functionExpr = funcExpr; + invokeExpr->arguments = argExprs; + + auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); + + return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef; +} + + +DeclRef<Decl> Linkage::specializeGeneric( + DeclRef<Decl> declRef, + List<Expr*> argExprs, + DiagnosticSink* sink) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + SLANG_ASSERT(declRef); + + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(sink); + + auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef); + + DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>(); + declRefExpr->declRef = genericDeclRef; + + GenericAppExpr* genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); + genericAppExpr->functionExpr = declRefExpr; + genericAppExpr->arguments = argExprs; + + auto specializedDeclRef = + as<DeclRefExpr>(visitor.checkGenericAppWithCheckedArgs(genericAppExpr))->declRef; + + return specializedDeclRef; +} + +SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout( + slang::TypeReflection* inType, + SlangInt targetIndex, + slang::LayoutRules rules, + ISlangBlob** outDiagnostics) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto type = asInternal(inType); + + if (targetIndex < 0 || targetIndex >= targets.getCount()) + return nullptr; + + auto target = targets[targetIndex]; + + // TODO: We need a way to pass through the layout rules + // that the user requested (e.g., constant buffers vs. + // structured buffer rules). Right now the API only + // exposes a single case, so this isn't a big deal. + // + SLANG_UNUSED(rules); + + auto typeLayout = target->getTypeLayout(type, rules); + + // TODO: We currently don't have a path for capturing + // errors that occur during layout (e.g., types that + // are invalid because of target-specific layout constraints). + // + SLANG_UNUSED(outDiagnostics); + + return asExternal(typeLayout); +} + +SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( + slang::TypeReflection* inType, + slang::ContainerType containerType, + ISlangBlob** outDiagnostics) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto type = asInternal(inType); + + Type* containerTypeReflection = nullptr; + ContainerTypeKey key = {inType, containerType}; + if (!m_containerTypes.tryGetValue(key, containerTypeReflection)) + { + switch (containerType) + { + case slang::ContainerType::ConstantBuffer: + { + SemanticsVisitor visitor(getSemanticsForReflection()); + auto layoutType = getASTBuilder()->getDefaultLayoutType(); + Type* cbType = visitor.getConstantBufferType(type, layoutType); + containerTypeReflection = cbType; + } + break; + case slang::ContainerType::ParameterBlock: + { + ParameterBlockType* pbType = getASTBuilder()->getParameterBlockType(type); + containerTypeReflection = pbType; + } + break; + case slang::ContainerType::StructuredBuffer: + { + HLSLStructuredBufferType* sbType = getASTBuilder()->getStructuredBufferType(type); + containerTypeReflection = sbType; + } + break; + case slang::ContainerType::UnsizedArray: + { + ArrayExpressionType* arrType = getASTBuilder()->getArrayType(type, nullptr); + containerTypeReflection = arrType; + } + break; + default: + containerTypeReflection = type; + break; + } + + m_containerTypes.add(key, containerTypeReflection); + } + + SLANG_UNUSED(outDiagnostics); + + return asExternal(containerTypeReflection); +} + +SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getDynamicType() +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + return asExternal(getASTBuilder()->getSharedASTBuilder()->getDynamicType()); +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Linkage::getTypeRTTIMangledName(slang::TypeReflection* type, ISlangBlob** outNameBlob) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto internalType = asInternal(type); + if (auto declRefType = as<DeclRefType>(internalType)) + { + auto name = getMangledName(m_astBuilder, declRefType->getDeclRef()); + Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); + *outNameBlob = blob.detach(); + return SLANG_OK; + } + return SLANG_FAIL; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessMangledName( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + ISlangBlob** outNameBlob) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto subType = asInternal(type); + auto supType = asInternal(interfaceType); + auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); + Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); + *outNameBlob = blob.detach(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequentialID( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outId) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + auto subType = asInternal(type); + auto supType = asInternal(interfaceType); + + if (!subType || !supType) + return SLANG_FAIL; + + auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); + auto interfaceName = getMangledTypeName(m_astBuilder, supType); + uint32_t resultIndex = 0; + if (mapMangledNameToRTTIObjectIndex.tryGetValue(name, resultIndex)) + { + if (outId) + *outId = resultIndex; + return SLANG_OK; + } + auto idAllocator = mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(interfaceName); + if (!idAllocator) + { + mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; + idAllocator = mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(interfaceName); + } + resultIndex = (*idAllocator); + ++(*idAllocator); + mapMangledNameToRTTIObjectIndex[name] = resultIndex; + if (outId) + *outId = resultIndex; + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outBuffer, + uint32_t bufferSize) +{ + // Slang RTTI header format: + // byte 0-7: pointer to RTTI struct describing the type. (not used for now, set to 1 for valid + // types, and 0 to represent null). + // byte 8-11: 32-bit sequential ID of the type conformance witness. + // byte 12-15: unused. + + if (bufferSize < 16) + return SLANG_E_BUFFER_TOO_SMALL; + + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + SLANG_RETURN_ON_FAIL(getTypeConformanceWitnessSequentialID(type, interfaceType, outBuffer + 2)); + + // Make the RTTI part non zero. + outBuffer[0] = 1; + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformanceComponentType, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) +{ + if (outConformanceComponentType == nullptr) + return SLANG_E_INVALID_ARG; + + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + RefPtr<TypeConformance> result; + DiagnosticSink sink; + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + + try + { + SemanticsVisitor visitor(getSemanticsForReflection()); + visitor = visitor.withSink(&sink); + + auto witness = visitor.isSubtype( + (Slang::Type*)type, + (Slang::Type*)interfaceType, + IsSubTypeOptions::None); + if (auto subtypeWitness = as<SubtypeWitness>(witness)) + { + result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink); + } + } + catch (...) + { + } + sink.getBlobIfNeeded(outDiagnostics); + bool success = (result != nullptr); + *outConformanceComponentType = result.detach(); + return success ? SLANG_OK : SLANG_FAIL; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +Linkage::createCompileRequest(SlangCompileRequest** outCompileRequest) +{ + auto compileRequest = new EndToEndCompileRequest(this); + compileRequest->addRef(); + *outCompileRequest = asExternal(compileRequest); + return SLANG_OK; +} + +SLANG_NO_THROW SlangInt SLANG_MCALL Linkage::getLoadedModuleCount() +{ + return loadedModulesList.getCount(); +} + +SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::getLoadedModule(SlangInt index) +{ + if (index >= 0 && index < loadedModulesList.getCount()) + return loadedModulesList[index].get(); + return nullptr; +} + +void Linkage::buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex) +{ + // Add the Slang compiler version to the hash + auto version = String(getBuildTagString()); + builder.append(version); + + // Add compiler options, including search path, preprocessor includes, etc. + m_optionSet.buildHash(builder); + + auto addTargetDigest = [&](TargetRequest* targetReq) + { + targetReq->getOptionSet().buildHash(builder); + + const PassThroughMode passThroughMode = + getDownstreamCompilerRequiredForTarget(targetReq->getTarget()); + const SourceLanguage sourceLanguage = + getDefaultSourceLanguageForDownstreamCompiler(passThroughMode); + + // Add prelude for the given downstream compiler. + ComPtr<ISlangBlob> prelude; + getGlobalSession()->getLanguagePrelude( + (SlangSourceLanguage)sourceLanguage, + prelude.writeRef()); + if (prelude) + { + builder.append(prelude); + } + + // TODO: Downstream compilers (specifically dxc) can currently #include additional + // dependencies. This is currently the case for NVAPI headers included in the prelude. These + // dependencies are currently not picked up by the shader cache which is a significant + // issue. This can only be fixed by running the preprocessor in the slang compiler so dxc + // (or any other downstream compiler for that matter) isn't resolving any includes + // implicitly. + + // Add the downstream compiler version (if it exists) to the hash + auto downstreamCompiler = + getSessionImpl()->getOrLoadDownstreamCompiler(passThroughMode, nullptr); + if (downstreamCompiler) + { + ComPtr<ISlangBlob> versionString; + if (SLANG_SUCCEEDED(downstreamCompiler->getVersionString(versionString.writeRef()))) + { + builder.append(versionString); + } + } + }; + + // Add the target specified by targetIndex + if (targetIndex == -1) + { + // -1 means all targets. + for (auto targetReq : targets) + { + addTargetDigest(targetReq); + } + } + else + { + auto targetReq = targets[targetIndex]; + addTargetDigest(targetReq); + } +} + +SlangResult Linkage::addSearchPath(char const* path) +{ + m_optionSet.add(CompilerOptionName::Include, String(path)); + return SLANG_OK; +} + +SlangResult Linkage::addPreprocessorDefine(char const* name, char const* value) +{ + CompilerOptionValue val; + val.kind = CompilerOptionValueKind::String; + val.stringValue = name; + val.stringValue2 = value; + m_optionSet.add(CompilerOptionName::MacroDefine, val); + return SLANG_OK; +} + +SlangResult Linkage::setMatrixLayoutMode(SlangMatrixLayoutMode mode) +{ + m_optionSet.setMatrixLayoutMode((MatrixLayoutMode)mode); + return SLANG_OK; +} + +SlangResult Linkage::loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) +{ + outPathInfo.type = PathInfo::Type::Unknown; + + SLANG_RETURN_ON_FAIL(m_fileSystemExt->loadFile(path.getBuffer(), outBlob)); + + ComPtr<ISlangBlob> uniqueIdentity; + // Get the unique identity + if (SLANG_FAILED( + m_fileSystemExt->getFileUniqueIdentity(path.getBuffer(), uniqueIdentity.writeRef()))) + { + // We didn't get a unique identity, so go with just a found path + outPathInfo.type = PathInfo::Type::FoundPath; + outPathInfo.foundPath = path; + } + else + { + outPathInfo = PathInfo::makeNormal(path, StringUtil::getString(uniqueIdentity)); + } + return SLANG_OK; +} + +Expr* Linkage::parseTermString(String typeStr, Scope* scope) +{ + // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc + // will be cleaned up + SourceManager localSourceManager; + localSourceManager.initialize(getSourceManager(), nullptr); + + Slang::SourceFile* srcFile = + localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr); + + // We'll use a temporary diagnostic sink + DiagnosticSink sink(&localSourceManager, nullptr); + + // RAII type to make make sure current SourceManager is restored after parse. + // Use RAII - to make sure everything is reset even if an exception is thrown. + struct ScopeReplaceSourceManager + { + ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager) + : m_linkage(linkage), m_originalSourceManager(linkage->getSourceManager()) + { + linkage->setSourceManager(replaceManager); + } + + ~ScopeReplaceSourceManager() { m_linkage->setSourceManager(m_originalSourceManager); } + + private: + Linkage* m_linkage; + SourceManager* m_originalSourceManager; + }; + + // We need to temporarily replace the SourceManager for this CompileRequest + ScopeReplaceSourceManager scopeReplaceSourceManager(this, &localSourceManager); + + SourceLanguage sourceLanguage = SourceLanguage::Slang; + SlangLanguageVersion languageVersion = m_optionSet.getLanguageVersion(); + + auto tokens = preprocessSource( + srcFile, + &sink, + nullptr, + Dictionary<String, String>(), + this, + sourceLanguage, + languageVersion); + + if (sourceLanguage == SourceLanguage::Unknown) + sourceLanguage = SourceLanguage::Slang; + + return parseTermFromSourceFile( + getASTBuilder(), + tokens, + &sink, + scope, + getNamePool(), + sourceLanguage); +} + +UInt Linkage::addTarget(CodeGenTarget target) +{ + RefPtr<TargetRequest> targetReq = new TargetRequest(this, target); + + Index result = targets.getCount(); + targets.add(targetReq); + return UInt(result); +} + +void Linkage::loadParsedModule( + RefPtr<FrontEndCompileRequest> compileRequest, + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + const PathInfo& pathInfo) +{ + // Note: we add the loaded module to our name->module listing + // before doing semantic checking, so that if it tries to + // recursively `import` itself, we can detect it. + // + RefPtr<Module> loadedModule = translationUnit->getModule(); + + // Get a path + String mostUniqueIdentity = pathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + + mapPathToLoadedModule.add(mostUniqueIdentity, loadedModule); + mapNameToLoadedModules.add(name, loadedModule); + + auto sink = translationUnit->compileRequest->getSink(); + + int errorCountBefore = sink->getErrorCount(); + int errorCountAfter; + try + { + compileRequest->checkAllTranslationUnits(); + } + catch (...) + { + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(name); + throw; + } + errorCountAfter = sink->getErrorCount(); + if (isInLanguageServer()) + { + // Don't generate IR as language server. + // This means that we currently cannot report errors that are detected during IR passes. + // Ideally we want to run those passes, but that is too risky for what it is worth right + // now. + } + else + { + if (errorCountAfter != errorCountBefore) + { + // There must have been an error in the loaded module. + // Remove from maps if there were errors during semantic checking + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(name); + } + else + { + // If we didn't run into any errors, then try to generate + // IR code for the imported module. + if (errorCountAfter == 0) + { + loadedModule->setIRModule( + generateIRForTranslationUnit(getASTBuilder(), translationUnit)); + } + } + } + loadedModulesList.add(loadedModule); +} + +RefPtr<Module> Linkage::findOrLoadSerializedModuleForModuleLibrary( + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* libraryChunk, + DiagnosticSink* sink) +{ + RefPtr<Module> resultModule; + + // We will attempt things in a few different steps, trying to + // decode as little of the serialized module as necessary at + // each step, so that we don't waste time on the heavyweight + // stuff when we didn't need to. + // + // The first step is to simply decode the module name, and + // see if we have a already loaded a matching module. + + auto moduleName = getNamePool()->getName(moduleChunk->getName()); + if (mapNameToLoadedModules.tryGetValue(moduleName, resultModule)) + return resultModule; + + // It is possible that the module has been loaded, but somehow + // under a different name, so next we decode the list of file + // paths that the module depends on, and then rely on the assumption + // that the first of those paths represents the file for the module + // itself to detect if we've already loaded a module from that + // path. + // + // Note: While this is a distasteful assumption to make, it is + // one that gets made in several parts of the compiler codebase + // already. It isn't something that can be fixed in just one + // place at this point. + + auto fileDependenciesList = moduleChunk->getFileDependencies(); + auto firstFileDependencyChunk = fileDependenciesList.getFirst(); + if (!firstFileDependencyChunk) + return nullptr; + + auto modulePathInfo = PathInfo::makePath(firstFileDependencyChunk->getValue()); + if (mapPathToLoadedModule.tryGetValue(modulePathInfo.getMostUniqueIdentity(), resultModule)) + return resultModule; + + // If we failed to find a previously-loaded module, then we + // will go ahead and load the module from the serialized form. + // + PathInfo filePathInfo; + return loadSerializedModule( + moduleName, + modulePathInfo, + blobHoldingSerializedData, + moduleChunk, + libraryChunk, + SourceLoc(), + sink); +} + +RefPtr<Module> Linkage::loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* containerChunk, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) +{ + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + + auto module = RefPtr(new Module(this, astBuilder)); + module->setName(moduleName); + + // Just as if we were processing an `import` declaration in + // source code, we will track the fact that this serialized + // modlue is (effectively) being imported, so that we can + // diagnose anything troublesome, like an attempt at a + // recursive import. + // + ModuleBeingImportedRAII moduleBeingImported(this, module, moduleName, requestingLoc); + + // We will register the module in our data structures to + // track loaded modules, and then remove it in the case + // where there is some kind of failure. + // + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + + mapPathToLoadedModule.add(mostUniqueIdentity, module); + mapNameToLoadedModules.add(moduleName, module); + try + { + if (SLANG_FAILED(loadSerializedModuleContents( + module, + moduleFilePathInfo, + blobHoldingSerializedData, + moduleChunk, + containerChunk, + sink))) + { + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(moduleName); + return nullptr; + } + + loadedModulesList.add(module); + return module; + } + catch (...) + { + mapPathToLoadedModule.remove(mostUniqueIdentity); + mapNameToLoadedModules.remove(moduleName); + throw; + } +} + +RefPtr<Module> Linkage::loadBinaryModuleImpl( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ISlangBlob* moduleFileContents, + SourceLoc const& requestingLoc, + DiagnosticSink* sink) +{ + auto astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + + // We start by reading the content of the file as + // an in-memory RIFF container. + // + auto rootChunk = RIFF::RootChunk::getFromBlob(moduleFileContents); + if (!rootChunk) + { + return nullptr; + } + + auto moduleChunk = ModuleChunk::find(rootChunk); + if (!moduleChunk) + { + return nullptr; + } + + // Next, we attempt to check if the binary module is up to + // date with the compilation options in use as well as + // the contents of all the files its compilation depended + // on (as determined by its hash). + // + String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); + SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); + if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) + { + if (!isBinaryModuleUpToDate(moduleFilePathInfo.foundPath, moduleChunk)) + { + return nullptr; + } + } + + // If everything seems reasonable, then we will go ahead and load + // the module more completely from that serialized representation. + // + RefPtr<Module> module = loadSerializedModule( + moduleName, + moduleFilePathInfo, + moduleFileContents, + moduleChunk, + rootChunk, + requestingLoc, + sink); + + return module; +} + +void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) +{ + for (auto info = m_modulesBeingImported; info; info = info->next) + { + sink->diagnose(info->importLoc, Diagnostics::errorInImportedModule, info->name); + } + if (!isInLanguageServer()) + { + sink->diagnose(SourceLoc(), Diagnostics::complationCeased); + } +} + +RefPtr<Module> Linkage::loadModuleImpl( + Name* moduleName, + const PathInfo& modulePathInfo, + ISlangBlob* moduleBlob, + SourceLoc const& requestingLoc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules, + ModuleBlobType blobType) +{ + switch (blobType) + { + case ModuleBlobType::IR: + return loadBinaryModuleImpl(moduleName, modulePathInfo, moduleBlob, requestingLoc, sink); + + case ModuleBlobType::Source: + return loadSourceModuleImpl( + moduleName, + modulePathInfo, + moduleBlob, + requestingLoc, + sink, + additionalLoadedModules); + + default: + SLANG_UNEXPECTED("unknown module blob type"); + UNREACHABLE_RETURN(nullptr); + } +} + +RefPtr<Module> Linkage::loadSourceModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* sourceBlob, + SourceLoc const& srcLoc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules) +{ + RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, nullptr, sink); + + frontEndReq->additionalLoadedModules = additionalLoadedModules; + + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq); + translationUnit->compileRequest = frontEndReq; + translationUnit->setModuleName(name); + Stage impliedStage; + translationUnit->sourceLanguage = SourceLanguage::Slang; + + // If we are loading from a file with apparaent glsl extension, + // set the source language to GLSL to enable GLSL compatibility mode. + if ((SourceLanguage)findSourceLanguageFromPath(filePathInfo.getName(), impliedStage) == + SourceLanguage::GLSL) + { + translationUnit->sourceLanguage = SourceLanguage::GLSL; + } + + frontEndReq->addTranslationUnit(translationUnit); + + auto module = translationUnit->getModule(); + + ModuleBeingImportedRAII moduleBeingImported(this, module, name, srcLoc); + + // Create an artifact for the source + auto sourceArtifact = ArtifactUtil::createArtifact( + ArtifactDesc::make(ArtifactKind::Source, ArtifactPayload::Slang, ArtifactStyle::Unknown)); + + if (sourceBlob) + { + // If the user has already provided a source blob, use that. + sourceArtifact->addRepresentation( + new SourceBlobWithPathInfoArtifactRepresentation(filePathInfo, sourceBlob)); + } + else if ( + filePathInfo.type == PathInfo::Type::Normal || + filePathInfo.type == PathInfo::Type::FoundPath) + { + // Create with the 'friendly' name + // We create that it was loaded from the file system + sourceArtifact->addRepresentation(new ExtFileArtifactRepresentation( + filePathInfo.foundPath.getUnownedSlice(), + getFileSystemExt())); + } + else + { + return nullptr; + } + + translationUnit->addSourceArtifact(sourceArtifact); + + if (SLANG_FAILED(translationUnit->requireSourceFiles())) + { + // Some problem accessing source files + return nullptr; + } + int errorCountBefore = sink->getErrorCount(); + frontEndReq->parseTranslationUnit(translationUnit); + int errorCountAfter = sink->getErrorCount(); + + if (errorCountAfter != errorCountBefore && !isInLanguageServer()) + { + _diagnoseErrorInImportedModule(sink); + // Something went wrong during the parsing, so we should bail out. + return nullptr; + } + + try + { + loadParsedModule(frontEndReq, translationUnit, name, filePathInfo); + } + catch (const Slang::AbortCompilationException&) + { + // Something is fatally wrong, we should return nullptr. + module = nullptr; + } + errorCountAfter = sink->getErrorCount(); + + if (errorCountAfter != errorCountBefore && !isInLanguageServer()) + { + // If something is fatally wrong, we want to report + // the diagnostic even if we are in language server + // and processing a different module. + _diagnoseErrorInImportedModule(sink); + // Something went wrong during the parsing, so we should bail out. + return nullptr; + } + + if (!module) + return nullptr; + + module->setPathInfo(filePathInfo); + return module; +} + +bool Linkage::isBeingImported(Module* module) +{ + for (auto ii = m_modulesBeingImported; ii; ii = ii->next) + { + if (module == ii->module) + return true; + } + return false; +} + +// Derive a file name for the module, by taking the given +// identifier, replacing all occurrences of `_` with `-`, +// and then appending `.slang`. +// +// For example, `foo_bar` becomes `foo-bar.slang`. +String getFileNameFromModuleName(Name* name, bool translateUnderScore) +{ + String fileName; + if (!getText(name).getUnownedSlice().endsWithCaseInsensitive(".slang")) + { + StringBuilder sb; + for (auto c : getText(name)) + { + if (translateUnderScore && c == '_') + c = '-'; + + sb.append(c); + } + sb.append(".slang"); + fileName = sb.produceString(); + } + else + { + fileName = getText(name); + } + return fileName; +} + +RefPtr<Module> Linkage::findOrImportModule( + Name* moduleName, + SourceLoc const& requestingLoc, + DiagnosticSink* sink, + const LoadedModuleDictionary* loadedModules) +{ + // Have we already loaded a module matching this name? + // + RefPtr<LoadedModule> previouslyLoadedModule; + if (mapNameToLoadedModules.tryGetValue(moduleName, previouslyLoadedModule)) + { + // If the map shows a null module having been loaded, + // then that means there was a prior load attempt, + // but it failed, so we won't bother trying again. + // + if (!previouslyLoadedModule) + return nullptr; + + // If state shows us that the module is already being + // imported deeper on the call stack, then we've + // hit a recursive case, and that is an error. + // + if (isBeingImported(previouslyLoadedModule)) + { + // We seem to be in the middle of loading this module + sink->diagnose(requestingLoc, Diagnostics::recursiveModuleImport, moduleName); + return nullptr; + } + + return previouslyLoadedModule; + } + + // If the user is providing an additional list of loaded modules, we find + // if the module being imported is in that list. This allows a translation + // unit to use previously checked translation units in the same + // FrontEndCompileRequest. + { + Module* previouslyLoadedLocalModule = nullptr; + if (loadedModules && loadedModules->tryGetValue(moduleName, previouslyLoadedLocalModule)) + { + return previouslyLoadedLocalModule; + } + } + + // If the name being requested matches the name of a built-in module, + // then we will special-case the process by loading that builtin + // module directly. + // + // TODO: right now this logic is only considering the built-in `glsl` + // module, but it should probably be generalized so that we can more + // easily support having multiple built-in modules rather than just + // putting everything into `core`. + // + if (moduleName == getSessionImpl()->glslModuleName) + { + // This is a builtin glsl module, just load it from embedded definition. + auto glslModule = getSessionImpl()->getBuiltinModule(slang::BuiltinModuleName::GLSL); + if (!glslModule) + { + // Note: the way this logic is currently written, if the built-in + // `glsl` module fails to load, then we will *not* fall back to + // searching for a user-defined module in a file like `glsl.slang`. + // + // It is unclear if this should be the default behavior or not. + // Should built-in modules be prioritized over user modules? + // Should built-in modules shadow user modules, even when the + // built-in module fails to load, for some reason? + // + sink->diagnose(requestingLoc, Diagnostics::glslModuleNotAvailable, moduleName); + } + return glslModule; + } + + // We are going to use a loop to search for a suitable file to + // load the module from, to account for a few key choices: + // + // * We can both load modules from a source `.slang` file, + // or from a binary `.slang-module` file. + // + // * For a variety of reasons, the `import` logic has historically + // translated underscores in a module name into dashes (so that + // `import my_module` will look for `my-module.slang`), and we + // try to support both that convention as well as a convention + // that preserves underscores. + // + // To try to keep this logic as orthogonal as possible, we first + // construct lists of the options we want to iterate over, and + // then do the actual loop later. + + ShortList<ModuleBlobType, 2> typesToTry; + if (isInLanguageServer()) + { + // When in language server, we always prefer to use source module if it is available. + typesToTry.add(ModuleBlobType::Source); + typesToTry.add(ModuleBlobType::IR); + } + else + { + // Look for a precompiled module first, if not exist, load from source. + typesToTry.add(ModuleBlobType::IR); + typesToTry.add(ModuleBlobType::Source); + } + + // We will always search for a file name that directly matches the + // module name as written first, and then search for one with + // underscores replaced by dashes. The latter is the original + // behavior that `import` provided, but it seems safest to prefer + // the exact name spelled in the user's code when there might + // actually be ambiguity. + // + auto defaultSourceFileName = getFileNameFromModuleName(moduleName, false); + auto alternativeSourceFileName = getFileNameFromModuleName(moduleName, true); + String sourceFileNamesToTry[] = {defaultSourceFileName, alternativeSourceFileName}; + + // We are going to look for the candidate file using the same + // logic that would be used for a preprocessor `#include`, + // so we set up the necessary state. + // + IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); + + // Just like with a `#include`, the search will take into + // account the path to the file where the request to import + // this module came from (e.g. the source file with the + // `import` declaration), if such a path is available. + // + PathInfo requestingPathInfo = + getSourceManager()->getPathInfo(requestingLoc, SourceLocType::Actual); + + for (auto type : typesToTry) + { + for (auto sourceFileName : sourceFileNamesToTry) + { + // The `sourceFileName` will have the `.slang` extension, + // so if we are looking for a binary module, we need + // to change the extension we will look for. + // + String fileName; + switch (type) + { + case ModuleBlobType::Source: + fileName = sourceFileName; + break; + + case ModuleBlobType::IR: + fileName = Path::replaceExt(sourceFileName, "slang-module"); + break; + } + + // We now search for a file matching the desired name, + // using the same logic as for a `#include`. + // + // TODO: We might want to consider how to handle the case + // of an `import` with a relative path a little specially, + // since it could in theory be possible for two `.slang` + // files with the same base name to exist in different + // directories in a project, and we'd want file-relative + // `import`s to work for each, without having either one + // be able to "claim" the bare identifier of the base + // name for itself. + // + PathInfo filePathInfo; + if (SLANG_FAILED( + includeSystem.findFile(fileName, requestingPathInfo.foundPath, filePathInfo))) + { + // If we failed to find the file at this step, we + // will continue the search for our other options. + // + continue; + } + + // We will *again* search for a previously loaded module. + // + // It is possible that the same file will have been loaded + // as a module under two different module names. The easiest + // way for this to happen is if there are `import` declarations + // using both the underscore and dash conventions (e.g., both + // `import "my-module.slang"` and `import my_module`). + // + // This case may also arise if one file `import`s a module using + // just an identifier for its name, but another `import`s it + // using a path (e.g., `import "subdir/file.slang"`). + // + // No matter how the situation arises, we only want to have one + // copy of the "same" module loaded at a given time, so we + // will re-use the existing module if we find one here. + // + if (mapPathToLoadedModule.tryGetValue( + filePathInfo.getMostUniqueIdentity(), + previouslyLoadedModule)) + { + // TODO: If we find a previously-loaded module at this step, + // then we should probably register that module under the + // given `moduleName` in the map of loaded modules, so + // that subsequent `import`s using the same form will find it. + // + return previouslyLoadedModule; + } + + // Now we try to load the content of the file. + // + // If for some reason we could find a file at the + // given path, but for some reason couldn't *open* + // and *read* it, then we continue the search + // using whatever other candidate file names are left. + // + ComPtr<ISlangBlob> fileContents; + if (SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) + { + continue; + } + + // If we found a real file and were able to load its contents, + // then we'll go ahead and try to load a module from it, + // whether by compiling it or decoding the binary. + // + auto module = loadModuleImpl( + moduleName, + filePathInfo, + fileContents, + requestingLoc, + sink, + loadedModules, + type); + + // If the attempt to load the module from the given path + // was successful, we go ahead and use it, without trying + // out any other options. + // + if (module) + return module; + } + } + + // If we tried out all of our candidate file names + // and failed with each of them, then we diagnose + // an error based on the original *source* file + // name. + // + // TODO: this should really be an error message + // that clearly states something like "no file + // suitable for module `whatever` was found + // and loaded. + // + // Ideally that error message would include whatever + // of the candidate file names from the loop above + // got furthest along in the process (or just a + // list of the file names that were tried, if + // nothing was even found via the include system). + // + sink->diagnose(requestingLoc, Diagnostics::cannotOpenFile, defaultSourceFileName); + + // If the attempt to import the module failed, then + // we will stick a null pointer into the map of loaded + // modules, so that subsequent attempts to load a module + // with this name will return null without having to + // go through all the above steps yet again. + // + mapNameToLoadedModules[moduleName] = nullptr; + return nullptr; +} + +SourceFile* Linkage::loadSourceFile(String pathFrom, String path) +{ + IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); + ComPtr<slang::IBlob> blob; + PathInfo pathInfo; + SLANG_RETURN_NULL_ON_FAIL(includeSystem.findFile(path, pathFrom, pathInfo)); + SourceFile* sourceFile = nullptr; + SLANG_RETURN_NULL_ON_FAIL(includeSystem.loadFile(pathInfo, blob, sourceFile)); + return sourceFile; +} + +// Check if a serialized module is up-to-date with current compiler options and source files. +bool Linkage::isBinaryModuleUpToDate(String fromPath, RIFF::ListChunk const* baseChunk) +{ + auto moduleChunk = ModuleChunk::find(baseChunk); + if (!moduleChunk) + return false; + + SHA1::Digest existingDigest = moduleChunk->getDigest(); + + DigestBuilder<SHA1> digestBuilder; + auto version = String(getBuildTagString()); + digestBuilder.append(version); + m_optionSet.buildHash(digestBuilder); + + // Find the canonical path of the directory containing the module source file. + String moduleSrcPath = ""; + + auto dependencyChunks = moduleChunk->getFileDependencies(); + if (auto firstDependencyChunk = dependencyChunks.getFirst()) + { + moduleSrcPath = firstDependencyChunk->getValue(); + + IncludeSystem includeSystem( + &getSearchDirectories(), + getFileSystemExt(), + getSourceManager()); + PathInfo modulePathInfo; + if (SLANG_SUCCEEDED(includeSystem.findFile(moduleSrcPath, fromPath, modulePathInfo))) + { + moduleSrcPath = modulePathInfo.foundPath; + Path::getCanonical(moduleSrcPath, moduleSrcPath); + } + } + + for (auto dependencyChunk : dependencyChunks) + { + auto file = dependencyChunk->getValue(); + auto sourceFile = loadSourceFile(fromPath, file); + if (!sourceFile) + { + // If we cannot find the source file from `fromPath`, + // try again from the module's source file path. + if (dependencyChunks.getFirst()) + sourceFile = loadSourceFile(moduleSrcPath, file); + } + if (!sourceFile) + return false; + digestBuilder.append(sourceFile->getDigest()); + } + return digestBuilder.finalize() == existingDigest; +} + +SLANG_NO_THROW bool SLANG_MCALL +Linkage::isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) +{ + auto rootChunk = RIFF::RootChunk::getFromBlob(binaryModuleBlob); + if (!rootChunk) + return false; + return isBinaryModuleUpToDate(modulePath, rootChunk); +} + +SourceFile* Linkage::findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem) +{ + auto impl = [&](bool translateUnderScore) -> SourceFile* + { + auto fileName = getFileNameFromModuleName(name, translateUnderScore); + + // Next, try to find the file of the given name, + // using our ordinary include-handling logic. + + auto& searchDirs = getSearchDirectories(); + outIncludeSystem = IncludeSystem(&searchDirs, getFileSystemExt(), getSourceManager()); + + // Get the original path info + PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); + PathInfo filePathInfo; + + ComPtr<ISlangBlob> fileContents; + + // We have to load via the found path - as that is how file was originally loaded + if (SLANG_FAILED( + outIncludeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) + { + return nullptr; + } + // Otherwise, try to load it. + SourceFile* sourceFile; + if (SLANG_FAILED(outIncludeSystem.loadFile(filePathInfo, fileContents, sourceFile))) + { + return nullptr; + } + return sourceFile; + }; + if (auto rs = impl(false)) + return rs; + return impl(true); +} + +Linkage::IncludeResult Linkage::findAndIncludeFile( + Module* module, + TranslationUnitRequest* translationUnit, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink) +{ + IncludeResult result; + result.fileDecl = nullptr; + result.isNew = false; + + IncludeSystem includeSystem; + auto sourceFile = findFile(name, loc, includeSystem); + if (!sourceFile) + { + sink->diagnose(loc, Diagnostics::cannotOpenFile, getText(name)); + return result; + } + + // If the file has already been included, don't need to do anything further. + if (auto existingFileDecl = module->getIncludedSourceFileMap().tryGetValue(sourceFile)) + { + result.fileDecl = *existingFileDecl; + result.isNew = false; + return result; + } + + if (isInLanguageServer()) + { + // HACK: When in language server mode, we will always load the currently opend file as a + // fresh module even if some previously opened file already references the current file via + // `import` or `include`. see comments in `WorkspaceVersion::getOrLoadModule()` for the + // reason behind this. An undesired outcome of this decision is that we could endup + // including the currently opened file itself via chain of `__include`s because the + // currently opened file will not have a true unique file system identity that allows it to + // be deduplicated correct. Therefore we insert a hack logic here to detect re-inclusion by + // just the file path. We can clean up this hack by making the language server truly support + // incremental checking so we can reuse the previously loaded module instead of needing to + // always start with a fresh copy. + // + for (auto file : translationUnit->getSourceFiles()) + { + if (file->getPathInfo().hasFoundPath() && + Path::equals(file->getPathInfo().foundPath, sourceFile->getPathInfo().foundPath)) + return result; + } + } + + module->addFileDependency(sourceFile); + + // Create a transparent FileDecl to hold all children from the included file. + auto fileDecl = module->getASTBuilder()->create<FileDecl>(); + fileDecl->nameAndLoc.name = name; + fileDecl->parentDecl = module->getModuleDecl(); + module->getIncludedSourceFileMap().add(sourceFile, fileDecl); + + FrontEndPreprocessorHandler preprocessorHandler( + module, + module->getASTBuilder(), + sink, + translationUnit); + auto combinedPreprocessorDefinitions = translationUnit->getCombinedPreprocessorDefinitions(); + SourceLanguage sourceLanguage = translationUnit->sourceLanguage; + SlangLanguageVersion slangLanguageVersion = module->getModuleDecl()->languageVersion; + auto tokens = preprocessSource( + sourceFile, + sink, + &includeSystem, + combinedPreprocessorDefinitions, + this, + sourceLanguage, + slangLanguageVersion, + &preprocessorHandler); + + if (sourceLanguage == SourceLanguage::Unknown) + sourceLanguage = translationUnit->sourceLanguage; + + if (slangLanguageVersion != module->getModuleDecl()->languageVersion) + { + sink->diagnose( + tokens.begin()->getLoc(), + Diagnostics::languageVersionDiffersFromIncludingModule); + } + + auto outerScope = module->getModuleDecl()->ownedScope; + parseSourceFile( + module->getASTBuilder(), + translationUnit, + sourceLanguage, + tokens, + sink, + outerScope, + fileDecl); + + module->getModuleDecl()->addMember(fileDecl); + + result.fileDecl = fileDecl; + result.isNew = true; + return result; +} + +void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) +{ + // Set the fileSystem + m_fileSystem = inFileSystem; + + // Release what's there + m_fileSystemExt.setNull(); + + // If nullptr passed in set up default + if (inFileSystem == nullptr) + { + m_fileSystemExt = new Slang::CacheFileSystem(Slang::OSFileSystem::getExtSingleton()); + } + else + { + if (auto cacheFileSystem = as<CacheFileSystem>(inFileSystem)) + { + m_fileSystemExt = cacheFileSystem; + } + else + { + if (m_requireCacheFileSystem) + { + m_fileSystemExt = new Slang::CacheFileSystem(inFileSystem); + } + else + { + // See if we have the full ISlangFileSystemExt interface, if we do just use it + inFileSystem->queryInterface(SLANG_IID_PPV_ARGS(m_fileSystemExt.writeRef())); + + // If not wrap with CacheFileSystem that emulates ISlangFileSystemExt from the + // ISlangFileSystem interface + if (!m_fileSystemExt) + { + // Construct a wrapper to emulate the extended interface behavior + m_fileSystemExt = new Slang::CacheFileSystem(m_fileSystem); + } + } + } + } + + // If requires a cache file system, check that it does have one + SLANG_ASSERT(m_requireCacheFileSystem == false || as<CacheFileSystem>(m_fileSystemExt)); + + // Set the file system used on the source manager + getSourceManager()->setFileSystemExt(m_fileSystemExt); +} + +SlangResult Linkage::loadSerializedModuleContents( + Module* module, + const PathInfo& moduleFilePathInfo, + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* containerChunk, + DiagnosticSink* sink) +{ + // At this point we've dealt with basically all of + // the formalities, and we just need to get down + // to the real work of decoding the information + // in the `moduleChunk`. + + // + // TODO(tfoley): The fact that a separate `containerChunk` is getting + // passed in here is entirely byproduct of the support for "module libraries" + // that can (in principle) contain multiple serialized modules. When + // things are serialized in the "container" representation used for + // a module library, there is a single `DebugChunk` as a child of + // the container, with all of the `ModuleChunk`s sharing that debug info. + // + // In contrast, the more typical kind of serialized module that the compiler + // produces serializes a single `ModuleChunk`, and the `DebugChunk` is + // one of its direct children. Thus there are currently two different + // locations where debug information might be found. + // + // Prior to the change where we navigate the serialized RIFF hierarchy + // in memory without copying it, this issue was addressed by having + // the subroutine that looked for a `DebugChunk` start at the `ModuleChunk` + // and work its way up through the hierarchy using parent pointers that + // were created as part of RIFF loading. When navigating the RIFF in-place + // we don't have such parent pointers. + // + // As a short-term solution, we should deprecate and remove the support + // for "module libraries" so that the code doesn't have to handle two + // different layouts. + // + // In the longer term, we should be making some conscious design decisions + // around how we want to organize the top-level structure of our serialized + // intermediate/output formats, since there's quite a mix of different + // approaches currently in use. + // + + auto sourceManager = getSourceManager(); + RefPtr<SerialSourceLocReader> sourceLocReader; + if (auto debugChunk = DebugChunk::find(moduleChunk, containerChunk)) + { + SLANG_RETURN_ON_FAIL( + readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); + } + + auto astChunk = moduleChunk->findAST(); + if (!astChunk) + return SLANG_FAIL; + + auto irChunk = moduleChunk->findIR(); + if (!irChunk) + return SLANG_FAIL; + + auto astBuilder = getASTBuilder(); + auto session = getSessionImpl(); + + // For the purposes of any modules referenced + // by the module we're about to decode, we will + // construct a source location that represents + // the module itself (if possible). + // + // TODO(tfoley): This logic seems like overkill, given + // that many (most? all?) control-flow paths that can + // reach this routine will have already found a `SourceFile` + // to represent the module, as part of even getting the + // `moduleFilePathInfo` to pass in + // + // The approach here is more or less exactly copied + // from what the old `SerialContainerUtil::read` function + // used to do, with the hopes that it will as many tests + // passing as possible. + // + // Down the line somebody should scrutinize all of this + // kind of logic in the compiler codebase, because there + // is something that feels unclean about how paths are being handled. + // + SourceLoc serializedModuleLoc; + { + auto sourceFile = + sourceManager->findSourceFileByPathRecursively(moduleFilePathInfo.foundPath); + if (!sourceFile) + { + sourceFile = sourceManager->createSourceFileWithString(moduleFilePathInfo, String()); + sourceManager->addSourceFile(moduleFilePathInfo.getMostUniqueIdentity(), sourceFile); + } + auto sourceView = + sourceManager->createSourceView(sourceFile, &moduleFilePathInfo, SourceLoc()); + serializedModuleLoc = sourceView->getRange().begin; + } + + auto moduleDecl = readSerializedModuleAST( + this, + astBuilder, + sink, + blobHoldingSerializedData, + astChunk, + sourceLocReader, + serializedModuleLoc); + if (!moduleDecl) + return SLANG_FAIL; + module->setModuleDecl(moduleDecl); + + RefPtr<IRModule> irModule; + SLANG_RETURN_ON_FAIL(readSerializedModuleIR(irChunk, session, sourceLocReader, irModule)); + module->setIRModule(irModule); + + // The handling of file dependencies is complicated, because of + // the way that the encoding logic tried to make all of the + // paths be relative to the primary source file for the module. + // + // We end up needing to undo some amount of that work here. + // + + module->clearFileDependency(); + String moduleSourcePath = moduleFilePathInfo.foundPath; + bool isFirst = true; + for (auto depenencyFileChunk : moduleChunk->getFileDependencies()) + { + auto encodedDependencyFilePath = depenencyFileChunk->getValue(); + + auto sourceFile = loadSourceFile(moduleFilePathInfo.foundPath, encodedDependencyFilePath); + if (isFirst) + { + // The first file is the source for the main module file. + // We store the module path as the basis for finding the remaining + // dependent files. + if (sourceFile) + moduleSourcePath = sourceFile->getPathInfo().foundPath; + isFirst = false; + } + // If we cannot find the dependent file directly, try to find + // it relative to the module source path. + if (!sourceFile) + { + sourceFile = loadSourceFile(moduleSourcePath, encodedDependencyFilePath); + } + if (sourceFile) + { + module->addFileDependency(sourceFile); + } + } + module->setPathInfo(moduleFilePathInfo); + module->setDigest(moduleChunk->getDigest()); + module->_collectShaderParams(); + module->_discoverEntryPoints(sink, targets); + + // Hook up fileDecl's scope to module's scope. + for (auto fileDecl : moduleDecl->getDirectMemberDeclsOfType<FileDecl>()) + { + addSiblingScopeForContainerDecl(m_astBuilder, moduleDecl->ownedScope, fileDecl); + } + + return SLANG_OK; +} + +void Linkage::setRequireCacheFileSystem(bool requireCacheFileSystem) +{ + if (requireCacheFileSystem == m_requireCacheFileSystem) + { + return; + } + + ComPtr<ISlangFileSystem> scopeFileSystem(m_fileSystem); + m_requireCacheFileSystem = requireCacheFileSystem; + + setFileSystem(scopeFileSystem); +} + +RefPtr<Module> findOrImportModule( + Linkage* linkage, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* loadedModules) +{ + return linkage->findOrImportModule(name, loc, sink, loadedModules); +} + +Type* checkProperType(Linkage* linkage, TypeExp typeExp, DiagnosticSink* sink); + + +} // namespace Slang diff --git a/source/slang/slang-session.h b/source/slang/slang-session.h new file mode 100644 index 000000000..5b3405d60 --- /dev/null +++ b/source/slang/slang-session.h @@ -0,0 +1,476 @@ +// slang-session.h +#pragma once + +// +// This file declares the `Linkage` class, which implements +// the `slang::ISession` interface from the public API. +// +// TODO: there is an unfortunate and confusing situation +// where the public Slang API `ISession` type is implemented +// by the internal `Linkage` class, while the internal +// `Session` class implements the `IGlobalSession` interface +// from the public API. +// + +#include "../compiler-core/slang-artifact.h" +#include "../compiler-core/slang-command-line-args.h" +#include "../compiler-core/slang-include-system.h" +#include "../compiler-core/slang-name.h" +#include "../core/slang-riff.h" +#include "../core/slang-smart-pointer.h" +#include "slang-ast-base.h" +#include "slang-compiler-fwd.h" +#include "slang-compiler-options.h" +#include "slang-content-assist-info.h" +#include "slang-global-session.h" + +#include <slang.h> + +namespace Slang +{ + +/// A dictionary of modules to be considered when resolving `import`s, +/// beyond those that would normally be found through a `Linkage`. +/// +/// Checking of an `import` declaration will bottleneck through +/// `Linkage::findOrImportModule`, which would usually just check for +/// any module that had been previously loaded into the same `Linkage` +/// (e.g., by a call to `Linkage::loadModule()`). +/// +/// In the case where compilation is being done through an +/// explicit `FrontEndCompileRequest` or `EndToEndCompileRequest`, +/// the modules being compiled by that request do not get added to +/// the surrounding `Linkage`. +/// +/// There is a corner case when an explicit compile request has +/// multiple `TranslationUnitRequest`s, because the user (reasonably) +/// expects that if they compile `A.slang` and `B.slang` as two +/// distinct translation units in the same compile request, then +/// an `import B` inside of `A.slang` should resolve to reference +/// the code of `B.slang`. But because neither `A` nor `B` gets +/// added to the `Linkage`, and the `Linkage` is what usually +/// determines what is or isn't loaded, that intuition will +/// be wrong, without a bit of help. +/// +/// The `LoadedModuleDictionary` is thus filled in by a +/// `FrontEndCompileRequest` to collect the modules it is compiling, +/// so that they can cross-reference one another (albeit with +/// a current implementation restriction that modules in the +/// request can only `import` those earlier in the request...). +/// +/// The dictionary then gets passed around between nearly all of +/// the operations that deal with loading modules, to make sure +/// that they can detect a previously loaded module. +/// +typedef Dictionary<Name*, Module*> LoadedModuleDictionary; + +enum class ModuleBlobType +{ + Source, + IR +}; + +struct ContainerTypeKey +{ + slang::TypeReflection* elementType; + slang::ContainerType containerType; + bool operator==(ContainerTypeKey other) const + { + return elementType == other.elementType && containerType == other.containerType; + } + Slang::HashCode getHashCode() const + { + return Slang::combineHash( + Slang::getHashCode(elementType), + Slang::getHashCode(containerType)); + } +}; + +/// A context for loading and re-using code modules. +class Linkage : public RefObject, public slang::ISession +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + CompilerOptionSet m_optionSet; + + ISlangUnknown* getInterface(const Guid& guid); + + SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL getGlobalSession() override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL + loadModule(const char* moduleName, slang::IBlob** outDiagnostics = nullptr) override; + slang::IModule* loadModuleFromBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + ModuleBlobType blobType, + slang::IBlob** outDiagnostics = nullptr); + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromIRBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW SlangResult SLANG_MCALL loadModuleInfoFromIRBlob( + slang::IBlob* source, + SlangInt& outModuleVersion, + const char*& outModuleCompilerVersion, + const char*& outModuleName) override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSource( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* string, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + slang::IComponentType** outCompositeComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType( + slang::TypeReflection* type, + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getTypeLayout( + slang::TypeReflection* type, + SlangInt targetIndex = 0, + slang::LayoutRules rules = slang::LayoutRules::Default, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getContainerType( + slang::TypeReflection* elementType, + slang::ContainerType containerType, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getDynamicType() override; + SLANG_NO_THROW SlangResult SLANG_MCALL + getTypeRTTIMangledName(slang::TypeReflection* type, ISlangBlob** outNameBlob) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + ISlangBlob** outNameBlob) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessSequentialID( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outBytes, + uint32_t bufferSize) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformance, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + createCompileRequest(SlangCompileRequest** outCompileRequest) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() override; + virtual SLANG_NO_THROW slang::IModule* SLANG_MCALL getLoadedModule(SlangInt index) override; + virtual SLANG_NO_THROW bool SLANG_MCALL + isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) override; + + // Updates the supplied builder with linkage-related information, which includes preprocessor + // defines, the compiler version, and other compiler options. This is then merged with the hash + // produced for the program to produce a key that can be used with the shader cache. + void buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex = -1); + + void addTarget(slang::TargetDesc const& desc); + SlangResult addSearchPath(char const* path); + SlangResult addPreprocessorDefine(char const* name, char const* value); + SlangResult setMatrixLayoutMode(SlangMatrixLayoutMode mode); + /// Create an initially-empty linkage + Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); + + /// Dtor + ~Linkage(); + + bool isInLanguageServer() + { + return contentAssistInfo.checkingMode != ContentAssistCheckingMode::None; + } + + /// Get the parent session for this linkage + Session* getSessionImpl() { return m_session; } + + // Information on the targets we are being asked to + // generate code for. + List<RefPtr<TargetRequest>> targets; + + // Directories to search for `#include` files or `import`ed modules + SearchDirectoryList& getSearchDirectories(); + + // Source manager to help track files loaded + SourceManager m_defaultSourceManager; + SourceManager* m_sourceManager = nullptr; + RefPtr<CommandLineContext> m_cmdLineContext; + + // Used to store strings returned by the api as const char* + StringSlicePool m_stringSlicePool; + + // Name pool for looking up names + NamePool* namePool = nullptr; + + NamePool* getNamePool() { return namePool; } + + ASTBuilder* getASTBuilder() { return m_astBuilder; } + + RefPtr<ASTBuilder> m_astBuilder; + + // Cache for container types. + Dictionary<ContainerTypeKey, Type*> m_containerTypes; + + // cache used by type checking, implemented in check.cpp + TypeCheckingCache* getTypeCheckingCache(); + void destroyTypeCheckingCache(); + + RefPtr<RefObject> m_typeCheckingCache = nullptr; + + // Modules that have been dynamically loaded via `import` + // + // This is a list of unique modules loaded, in the order they were encountered. + List<RefPtr<LoadedModule>> loadedModulesList; + + // Map from the path (or uniqueIdentity if available) of a module file to its definition + Dictionary<String, RefPtr<LoadedModule>> mapPathToLoadedModule; + + // Map from the logical name of a module to its definition + Dictionary<Name*, RefPtr<LoadedModule>> mapNameToLoadedModules; + + // Map from the mangled name of RTTI objects to sequential IDs + // used by `switch`-based dynamic dispatch. + Dictionary<String, uint32_t> mapMangledNameToRTTIObjectIndex; + + // Counters for allocating sequential IDs to witness tables conforming to each interface type. + Dictionary<String, uint32_t> mapInterfaceMangledNameToSequentialIDCounters; + + SearchDirectoryList searchDirectoryCache; + + // The resulting specialized IR module for each entry point request + List<RefPtr<IRModule>> compiledModules; + + ContentAssistInfo contentAssistInfo; + + /// File system implementation to use when loading files from disk. + /// + /// If this member is `null`, a default implementation that tries + /// to use the native OS filesystem will be used instead. + /// + ComPtr<ISlangFileSystem> m_fileSystem; + + /// The extended file system implementation. Will be set to a default implementation + /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, + /// or a wrapped impl that makes fileSystem operate as fileSystemExt + ComPtr<ISlangFileSystemExt> m_fileSystemExt; + + /// Get the currenly set file system + ISlangFileSystemExt* getFileSystemExt() { return m_fileSystemExt; } + + /// Load a file into memory using the configured file system. + /// + /// @param path The path to attempt to load from + /// @param outBlob A destination pointer to receive the loaded blob + /// @returns A `SlangResult` to indicate success or failure. + /// + SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob); + + Expr* parseTermString(String str, Scope* scope); + + Type* specializeType( + Type* unspecializedType, + Int argCount, + Type* const* args, + DiagnosticSink* sink); + + /// Add a new target and return its index. + UInt addTarget(CodeGenTarget target); + + /// "Bottleneck" routine for loading a module. + /// + /// All attempts to load a module, whether through + /// Slang API calls, `import` operations, or other + /// means, should bottleneck through `loadModuleImpl`, + /// or one of the specialized cases `loadSourceModuleImpl` + /// and `loadBinaryModuleImpl`. + /// + RefPtr<Module> loadModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules, + ModuleBlobType blobType); + + RefPtr<Module> loadSourceModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules); + + RefPtr<Module> loadBinaryModuleImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + /// Either finds a previously-loaded module matching what + /// was serialized into `moduleChunk`, or else attempts + /// to load the serialized module. + /// + /// If a previously-loaded module is found that matches the + /// name or path information in `moduleChunk`, then that + /// previously-loaded module is returned. + /// + /// Othwerise, attempts to load a module from `moduleChunk` + /// and, if successful, returns the freshly loaded module. + /// + /// Otherwise, return null. + /// + RefPtr<Module> findOrLoadSerializedModuleForModuleLibrary( + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* libraryChunk, + DiagnosticSink* sink); + + RefPtr<Module> loadSerializedModule( + Name* moduleName, + const PathInfo& moduleFilePathInfo, + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* containerChunk, //< The outer container, if there is one. + SourceLoc const& requestingLoc, + DiagnosticSink* sink); + + SlangResult loadSerializedModuleContents( + Module* module, + const PathInfo& moduleFilePathInfo, + ISlangBlob* blobHoldingSerializedData, + ModuleChunk const* moduleChunk, + RIFF::ListChunk const* containerChunk, //< The outer container, if there is one. + DiagnosticSink* sink); + + SourceFile* loadSourceFile(String pathFrom, String path); + + void loadParsedModule( + RefPtr<FrontEndCompileRequest> compileRequest, + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + PathInfo const& pathInfo); + + bool isBinaryModuleUpToDate(String fromPath, RIFF::ListChunk const* baseChunk); + + RefPtr<Module> findOrImportModule( + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* loadedModules = nullptr); + + SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); + struct IncludeResult + { + FileDecl* fileDecl; + bool isNew; + }; + IncludeResult findAndIncludeFile( + Module* module, + TranslationUnitRequest* translationUnit, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); + + SourceManager* getSourceManager() { return m_sourceManager; } + + /// Override the source manager for the linkage. + /// + /// This is only used to install a temporary override when + /// parsing stuff from strings (where we don't want to retain + /// full source files for the parsed result). + /// + /// TODO: We should remove the need for this hack. + /// + void setSourceManager(SourceManager* sourceManager) { m_sourceManager = sourceManager; } + + void setRequireCacheFileSystem(bool requireCacheFileSystem); + + void setFileSystem(ISlangFileSystem* fileSystem); + + DeclRef<Decl> specializeGeneric( + DeclRef<Decl> declRef, + List<Expr*> argExprs, + DiagnosticSink* sink); + + DeclRef<Decl> specializeWithArgTypes( + Expr* funcExpr, + List<Type*> argTypes, + DiagnosticSink* sink); + + bool isSpecialized(DeclRef<Decl> declRef); + + DiagnosticSink::Flags diagnosticSinkFlags = 0; + + bool m_requireCacheFileSystem = false; + + // Modules that have been read in with the -r option + List<ComPtr<IArtifact>> m_libModules; + + void _stopRetainingParentSession() { m_retainedSession = nullptr; } + + // Get shared semantics information for reflection purposes. + SharedSemanticsContext* getSemanticsForReflection(); + +private: + /// The global Slang library session that this linkage is a child of + Session* m_session = nullptr; + + RefPtr<Session> m_retainedSession; + + /// Tracks state of modules currently being loaded. + /// + /// This information is used to diagnose cases where + /// a user tries to recursively import the same module + /// (possibly along a transitive chain of `import`s). + /// + struct ModuleBeingImportedRAII + { + public: + ModuleBeingImportedRAII( + Linkage* linkage, + Module* module, + Name* name, + SourceLoc const& importLoc) + : linkage(linkage), module(module), name(name), importLoc(importLoc) + { + next = linkage->m_modulesBeingImported; + linkage->m_modulesBeingImported = this; + } + + ~ModuleBeingImportedRAII() { linkage->m_modulesBeingImported = next; } + + Linkage* linkage; + Module* module; + Name* name; + SourceLoc importLoc; + ModuleBeingImportedRAII* next; + }; + + // Any modules currently being imported will be listed here + ModuleBeingImportedRAII* m_modulesBeingImported = nullptr; + + /// Is the given module in the middle of being imported? + bool isBeingImported(Module* module); + + /// Diagnose that an error occured in the process of importing a module + void _diagnoseErrorInImportedModule(DiagnosticSink* sink); + + List<Type*> m_specializedTypes; + + RefPtr<SharedSemanticsContext> m_semanticsForReflection; +}; +} // namespace Slang diff --git a/source/slang/slang-target-program.cpp b/source/slang/slang-target-program.cpp new file mode 100644 index 000000000..ffb859b55 --- /dev/null +++ b/source/slang/slang-target-program.cpp @@ -0,0 +1,113 @@ +// slang-target-program.cpp +#include "slang-target-program.h" + +#include "slang-compiler.h" +#include "slang-type-layout.h" + +namespace Slang +{ + +// +// TargetProgram +// + +TargetProgram::TargetProgram(ComponentType* componentType, TargetRequest* targetReq) + : m_program(componentType), m_targetReq(targetReq) +{ + m_entryPointResults.setCount(componentType->getEntryPointCount()); + m_optionSet.overrideWith(m_program->getOptionSet()); + m_optionSet.inheritFrom(targetReq->getOptionSet()); +} + +IArtifact* TargetProgram::_createWholeProgramResult( + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) +{ + // We want to call `emitEntryPoints` function to generate code that contains + // all the entrypoints defined in `m_program`. + // The current logic of `emitEntryPoints` takes a list of entry-point indices to + // emit code for, so we construct such a list first. + List<Int> entryPointIndices; + + m_entryPointResults.setCount(m_program->getEntryPointCount()); + entryPointIndices.setCount(m_program->getEntryPointCount()); + for (Index i = 0; i < entryPointIndices.getCount(); i++) + entryPointIndices[i] = i; + + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); + + if (SLANG_FAILED(codeGenContext.emitEntryPoints(m_wholeProgramResult))) + { + return nullptr; + } + + return m_wholeProgramResult; +} + +IArtifact* TargetProgram::_createEntryPointResult( + Int entryPointIndex, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) +{ + // It is possible that entry points got added to the `Program` + // *after* we created this `TargetProgram`, so there might be + // a request for an entry point that we didn't allocate space for. + // + // TODO: Change the construction logic so that a `Program` is + // constructed all at once rather than incrementally, to avoid + // this problem. + // + if (entryPointIndex >= m_entryPointResults.getCount()) + m_entryPointResults.setCount(entryPointIndex + 1); + + + CodeGenContext::EntryPointIndices entryPointIndices; + entryPointIndices.add(entryPointIndex); + + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); + + codeGenContext.emitEntryPoints(m_entryPointResults[entryPointIndex]); + + return m_entryPointResults[entryPointIndex]; +} + +IArtifact* TargetProgram::getOrCreateWholeProgramResult(DiagnosticSink* sink) +{ + if (m_wholeProgramResult) + return m_wholeProgramResult; + + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) + { + return nullptr; + } + + return _createWholeProgramResult(sink); +} + +IArtifact* TargetProgram::getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink) +{ + if (entryPointIndex >= m_entryPointResults.getCount()) + m_entryPointResults.setCount(entryPointIndex + 1); + + if (IArtifact* artifact = m_entryPointResults[entryPointIndex]) + return artifact; + + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) + { + return nullptr; + } + + return _createEntryPointResult(entryPointIndex, sink); +} + +} // namespace Slang diff --git a/source/slang/slang-target-program.h b/source/slang/slang-target-program.h new file mode 100644 index 000000000..57a9c46fb --- /dev/null +++ b/source/slang/slang-target-program.h @@ -0,0 +1,143 @@ +// slang-target-program.h +#pragma once + +// +// This file declares the `TargetProgram` class, which is +// primarily used to cache generated target code for a +// linked program/binary and/or its entry points. +// + +#include "../core/slang-smart-pointer.h" +#include "slang-hlsl-to-vulkan-layout-options.h" +#include "slang-ir.h" +#include "slang-linkable.h" +#include "slang-target.h" + +namespace Slang +{ + +/// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` +/// +/// TODO: This should probably be renamed to `TargetComponentType`. +/// +/// By binding a component type to a specific target, a `TargetProgram` allows +/// for things like layout to be computed, that fundamentally depend on +/// the choice of target. +/// +/// A `TargetProgram` handles request for compiled kernel code for +/// entry point functions. In practice, kernel code can only be +/// correctly generated when the underlying `ComponentType` is "fully linked" +/// (has no remaining unsatisfied requirements). +/// +class TargetProgram : public RefObject +{ +public: + TargetProgram(ComponentType* componentType, TargetRequest* targetReq); + + /// Get the underlying program + ComponentType* getProgram() { return m_program; } + + /// Get the underlying target + TargetRequest* getTargetReq() { return m_targetReq; } + + /// Get the layout for the program on the target. + /// + /// If this is the first time the layout has been + /// requested, report any errors that arise during + /// layout to the given `sink`. + /// + ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); + + /// Get the layout for the program on the target. + /// + /// This routine assumes that `getOrCreateLayout` + /// has already been called previously. + /// + ProgramLayout* getExistingLayout() + { + SLANG_ASSERT(m_layout); + return m_layout; + } + + /// Get the compiled code for an entry point on the target. + /// + /// If this is the first time that code generation has + /// been requested, report any errors that arise during + /// code generation to the given `sink`. + /// + IArtifact* getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink); + IArtifact* getOrCreateWholeProgramResult(DiagnosticSink* sink); + + IArtifact* getExistingWholeProgramResult() { return m_wholeProgramResult; } + /// Get the compiled code for an entry point on the target. + /// + /// This routine assumes that `getOrCreateEntryPointResult` + /// has already been called previously. + /// + IArtifact* getExistingEntryPointResult(Int entryPointIndex) + { + return m_entryPointResults[entryPointIndex]; + } + + IArtifact* _createWholeProgramResult( + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq = nullptr); + + /// Internal helper for `getOrCreateEntryPointResult`. + /// + /// This is used so that command-line and API-based + /// requests for code can bottleneck through the same place. + /// + /// Shouldn't be called directly by most code. + /// + IArtifact* _createEntryPointResult( + Int entryPointIndex, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq = nullptr); + + RefPtr<IRModule> getOrCreateIRModuleForLayout(DiagnosticSink* sink); + + RefPtr<IRModule> getExistingIRModuleForLayout() { return m_irModuleForLayout; } + + CompilerOptionSet& getOptionSet() { return m_optionSet; } + + HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions() + { + return m_targetReq->getHLSLToVulkanLayoutOptions(); + } + + bool shouldEmitSPIRVDirectly() + { + return isKhronosTarget(m_targetReq) && getOptionSet().shouldEmitSPIRVDirectly(); + } + +private: + RefPtr<IRModule> createIRModuleForLayout(DiagnosticSink* sink); + + // The program being compiled or laid out + ComponentType* m_program; + + // The target that code/layout will be generated for + TargetRequest* m_targetReq; + + // The computed layout, if it has been generated yet + RefPtr<ProgramLayout> m_layout; + + CompilerOptionSet m_optionSet; + + // Generated compile results for each entry point + // in the parent `Program` (indexing matches + // the order they are given in the `Program`) + ComPtr<IArtifact> m_wholeProgramResult; + List<ComPtr<IArtifact>> m_entryPointResults; + + RefPtr<IRModule> m_irModuleForLayout; +}; + +/// Given a target request returns which (if any) intermediate source language is required +/// to produce it. +/// +/// If no intermediate source language is required, will return SourceLanguage::Unknown +SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* req); + +} // namespace Slang diff --git a/source/slang/slang-target.cpp b/source/slang/slang-target.cpp new file mode 100644 index 000000000..29430abde --- /dev/null +++ b/source/slang/slang-target.cpp @@ -0,0 +1,248 @@ +// slang-target.cpp +#include "slang-target.h" + +#include "../core/slang-type-text-util.h" +#include "compiler-core/slang-artifact-desc-util.h" +#include "slang-compiler.h" +#include "slang-type-layout.h" + +namespace Slang +{ + +bool isHeterogeneousTarget(CodeGenTarget target) +{ + return ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).style == + ArtifactStyle::Host; +} + +void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) +{ + UnownedStringSlice name = TypeTextUtil::getCompileTargetName(asExternal(val)); + name = name.getLength() ? name : toSlice("<unknown>"); + sb << name; +} + +// +// TargetRequest +// + +TargetRequest::TargetRequest(Linkage* linkage, CodeGenTarget format) + : linkage(linkage) +{ + optionSet = linkage->m_optionSet; + optionSet.add(CompilerOptionName::Target, format); +} + +TargetRequest::TargetRequest(const TargetRequest& other) + : RefObject(), linkage(other.linkage), optionSet(other.optionSet) +{ +} + + +Session* TargetRequest::getSession() +{ + return linkage->getSessionImpl(); +} + +HLSLToVulkanLayoutOptions* TargetRequest::getHLSLToVulkanLayoutOptions() +{ + if (!hlslToVulkanOptions) + { + hlslToVulkanOptions = new HLSLToVulkanLayoutOptions(); + hlslToVulkanOptions->loadFromOptionSet(optionSet); + } + return hlslToVulkanOptions.get(); +} + +void TargetRequest::setTargetCaps(CapabilitySet capSet) +{ + cookedCapabilities = capSet; +} + +CapabilitySet TargetRequest::getTargetCaps() +{ + if (!cookedCapabilities.isEmpty()) + return cookedCapabilities; + + // The full `CapabilitySet` for the target will be computed + // from the combination of the code generation format, and + // the profile. + // + // Note: the preofile might have been set in a way that is + // inconsistent with the output code format of SPIR-V, but + // a profile of Direct3D Shader Model 5.1. In those cases, + // the format should always override the implications in + // the profile. + // + // TODO: This logic isn't currently taking int account + // the information in the profile, because the current + // `CapabilityAtom`s that we support don't include any + // of the details there (e.g., the shader model versions). + // + // Eventually, we'd want to have a rich set of capability + // atoms, so that most of the information about what operations + // are available where can be directly encoded on the declarations. + + List<CapabilityName> atoms; + + // If the user specified a explicit profile, we should pull + // a corresponding atom representing the target version from the profile. + CapabilitySet profileCaps = optionSet.getProfile().getCapabilityName(); + + bool isGLSLTarget = false; + switch (getTarget()) + { + case CodeGenTarget::GLSL: + isGLSLTarget = true; + atoms.add(CapabilityName::glsl); + break; + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + if (getOptionSet().shouldEmitSPIRVDirectly()) + { + // Default to SPIRV 1.5 if the user has not specified a target version. + bool hasTargetVersionAtom = false; + if (!profileCaps.isEmpty()) + { + profileCaps.join(CapabilitySet(CapabilityName::spirv_1_0)); + for (auto profileCapAtomSet : profileCaps.getAtomSets()) + { + for (auto atom : profileCapAtomSet) + { + if (isTargetVersionAtom(asAtom(atom))) + { + atoms.add((CapabilityName)atom); + hasTargetVersionAtom = true; + } + } + } + } + if (!hasTargetVersionAtom) + { + atoms.add(CapabilityName::spirv_1_5); + } + // If the user specified any SPIR-V extensions in the profile, + // pull them in. + for (auto profileCapAtomSet : profileCaps.getAtomSets()) + { + for (auto atom : profileCapAtomSet) + { + if (isSpirvExtensionAtom(asAtom(atom))) + { + atoms.add((CapabilityName)atom); + hasTargetVersionAtom = true; + } + } + } + } + else + { + isGLSLTarget = true; + atoms.add(CapabilityName::glsl); + profileCaps.addSpirvVersionFromOtherAsGlslSpirvVersion(profileCaps); + } + break; + + case CodeGenTarget::HLSL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + atoms.add(CapabilityName::hlsl); + break; + + case CodeGenTarget::CSource: + atoms.add(CapabilityName::c); + break; + + case CodeGenTarget::CPPSource: + case CodeGenTarget::PyTorchCppBinding: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostSharedLibrary: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::ShaderHostCallable: + atoms.add(CapabilityName::cpp); + break; + + case CodeGenTarget::CUDASource: + case CodeGenTarget::PTX: + atoms.add(CapabilityName::cuda); + break; + + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: + atoms.add(CapabilityName::metal); + break; + + case CodeGenTarget::WGSLSPIRV: + case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::WGSL: + atoms.add(CapabilityName::wgsl); + break; + + default: + break; + } + + CapabilitySet targetCap = CapabilitySet(atoms); + + if (profileCaps.atLeastOneSetImpliedInOther(targetCap) == + CapabilitySet::ImpliesReturnFlags::Implied) + targetCap.join(profileCaps); + + for (auto atomVal : optionSet.getArray(CompilerOptionName::Capability)) + { + CapabilitySet toAdd; + switch (atomVal.kind) + { + case CompilerOptionValueKind::Int: + toAdd = CapabilitySet(CapabilityName(atomVal.intValue)); + break; + case CompilerOptionValueKind::String: + toAdd = CapabilitySet(findCapabilityName(atomVal.stringValue.getUnownedSlice())); + break; + } + + if (isGLSLTarget) + targetCap.addSpirvVersionFromOtherAsGlslSpirvVersion(toAdd); + + if (!targetCap.isIncompatibleWith(toAdd)) + targetCap.join(toAdd); + } + + cookedCapabilities = targetCap; + + SLANG_ASSERT(!cookedCapabilities.isInvalid()); + + return cookedCapabilities; +} + + +TypeLayout* TargetRequest::getTypeLayout(Type* type, slang::LayoutRules rules) +{ + SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); + + // TODO: We are not passing in a `ProgramLayout` here, although one + // is nominally required to establish the global ordering of + // generic type parameters, which might be referenced from field types. + // + // The solution here is to make sure that the reflection data for + // uses of global generic/existential types does *not* include any + // kind of index in that global ordering, and just refers to the + // parameter instead (leaving the user to figure out how that + // maps to the ordering via some API on the program layout). + // + auto layoutContext = getInitialLayoutContextForTarget(this, nullptr, rules); + + RefPtr<TypeLayout> result; + auto key = TypeLayoutKey{type, rules}; + if (getTypeLayouts().tryGetValue(key, result)) + return result.Ptr(); + result = createTypeLayout(layoutContext, type); + getTypeLayouts()[key] = result; + return result.Ptr(); +} + +} // namespace Slang diff --git a/source/slang/slang-target.h b/source/slang/slang-target.h new file mode 100644 index 000000000..6ac52657f --- /dev/null +++ b/source/slang/slang-target.h @@ -0,0 +1,141 @@ +// slang-target.h +#pragma once + +// +// This file declares the `TargetRequest` class, which is +// the primary way that the Slang compiler groups together +// a compilation target and options that affect output +// code generation and/or layout for that target. +// + +#include "../core/slang-string.h" +#include "slang-ast-base.h" +#include "slang-compiler-fwd.h" +#include "slang-compiler-options.h" +#include "slang-hlsl-to-vulkan-layout-options.h" + +#include <slang.h> + +namespace Slang +{ + +enum class CodeGenTarget : SlangCompileTargetIntegral +{ + Unknown = SLANG_TARGET_UNKNOWN, + None = SLANG_TARGET_NONE, + GLSL = SLANG_GLSL, + HLSL = SLANG_HLSL, + SPIRV = SLANG_SPIRV, + SPIRVAssembly = SLANG_SPIRV_ASM, + DXBytecode = SLANG_DXBC, + DXBytecodeAssembly = SLANG_DXBC_ASM, + DXIL = SLANG_DXIL, + DXILAssembly = SLANG_DXIL_ASM, + CSource = SLANG_C_SOURCE, + CPPSource = SLANG_CPP_SOURCE, + PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, + HostCPPSource = SLANG_HOST_CPP_SOURCE, + HostExecutable = SLANG_HOST_EXECUTABLE, + HostSharedLibrary = SLANG_HOST_SHARED_LIBRARY, + ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, + ShaderHostCallable = SLANG_SHADER_HOST_CALLABLE, + CUDASource = SLANG_CUDA_SOURCE, + PTX = SLANG_PTX, + CUDAObjectCode = SLANG_CUDA_OBJECT_CODE, + ObjectCode = SLANG_OBJECT_CODE, + HostHostCallable = SLANG_HOST_HOST_CALLABLE, + Metal = SLANG_METAL, + MetalLib = SLANG_METAL_LIB, + MetalLibAssembly = SLANG_METAL_LIB_ASM, + WGSL = SLANG_WGSL, + WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM, + WGSLSPIRV = SLANG_WGSL_SPIRV, + HostVM = SLANG_HOST_VM, + CountOf = SLANG_TARGET_COUNT_OF, +}; + +bool isHeterogeneousTarget(CodeGenTarget target); + +void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); + +class TargetRequest; + +/// Are we generating code for a D3D API? +bool isD3DTarget(TargetRequest* targetReq); + +// Are we generating code for Metal? +bool isMetalTarget(TargetRequest* targetReq); + +/// Are we generating code for a Khronos API (OpenGL or Vulkan)? +bool isKhronosTarget(TargetRequest* targetReq); +bool isKhronosTarget(CodeGenTarget target); + +/// Are we generating code for a CUDA API (CUDA / OptiX)? +bool isCUDATarget(TargetRequest* targetReq); + +// Are we generating code for a CPU target +bool isCPUTarget(TargetRequest* targetReq); + +/// Are we generating code for the WebGPU API? +bool isWGPUTarget(TargetRequest* targetReq); +bool isWGPUTarget(CodeGenTarget target); + +/// A request to generate output in some target format. +class TargetRequest : public RefObject +{ +public: + TargetRequest(Linkage* linkage, CodeGenTarget format); + + TargetRequest(const TargetRequest& other); + + Linkage* getLinkage() { return linkage; } + + Session* getSession(); + + CodeGenTarget getTarget() + { + return optionSet.getEnumOption<CodeGenTarget>(CompilerOptionName::Target); + } + + // TypeLayouts created on the fly by reflection API + struct TypeLayoutKey + { + Type* type; + slang::LayoutRules rules; + HashCode getHashCode() const + { + Hasher hasher; + hasher.hashValue(type); + hasher.hashValue(rules); + return hasher.getResult(); + } + bool operator==(TypeLayoutKey other) const + { + return type == other.type && rules == other.rules; + } + }; + Dictionary<TypeLayoutKey, RefPtr<TypeLayout>> typeLayouts; + + Dictionary<TypeLayoutKey, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; } + + TypeLayout* getTypeLayout(Type* type, slang::LayoutRules rules); + + CompilerOptionSet& getOptionSet() { return optionSet; } + + CapabilitySet getTargetCaps(); + + void setTargetCaps(CapabilitySet capSet); + + HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions(); + +private: + Linkage* linkage = nullptr; + CompilerOptionSet optionSet; + CapabilitySet cookedCapabilities; + RefPtr<HLSLToVulkanLayoutOptions> hlslToVulkanOptions; +}; + +/// Are resource types "bindless" (implemented as ordinary data) on the given `target`? +bool areResourceTypesBindlessOnTarget(TargetRequest* target); + +} // namespace Slang diff --git a/source/slang/slang-translation-unit.cpp b/source/slang/slang-translation-unit.cpp new file mode 100644 index 000000000..049de01eb --- /dev/null +++ b/source/slang/slang-translation-unit.cpp @@ -0,0 +1,251 @@ +// slang-translation-unit.cpp +#include "slang-translation-unit.h" + +#include "slang-compiler.h" + +namespace Slang +{ + +// +// TranslationUnitRequest +// + +TranslationUnitRequest::TranslationUnitRequest(FrontEndCompileRequest* compileRequest) + : compileRequest(compileRequest) +{ + module = new Module(compileRequest->getLinkage()); +} + +TranslationUnitRequest::TranslationUnitRequest(FrontEndCompileRequest* compileRequest, Module* m) + : compileRequest(compileRequest), module(m), isChecked(true) +{ + moduleName = getNamePool()->getName(m->getName()); +} + +Session* TranslationUnitRequest::getSession() +{ + return compileRequest->getSession(); +} + +NamePool* TranslationUnitRequest::getNamePool() +{ + return compileRequest->getNamePool(); +} + +SourceManager* TranslationUnitRequest::getSourceManager() +{ + return compileRequest->getSourceManager(); +} + +Scope* TranslationUnitRequest::getLanguageScope() +{ + Scope* languageScope = nullptr; + switch (sourceLanguage) + { + case SourceLanguage::HLSL: + languageScope = getSession()->hlslLanguageScope; + break; + case SourceLanguage::GLSL: + languageScope = getSession()->glslLanguageScope; + break; + case SourceLanguage::Slang: + default: + languageScope = getSession()->slangLanguageScope; + break; + } + return languageScope; +} + +Dictionary<String, String> TranslationUnitRequest::getCombinedPreprocessorDefinitions() +{ + Dictionary<String, String> combinedPreprocessorDefinitions; + for (const auto& def : preprocessorDefinitions) + combinedPreprocessorDefinitions.addIfNotExists(def); + for (const auto& def : compileRequest->optionSet.getArray(CompilerOptionName::MacroDefine)) + combinedPreprocessorDefinitions.addIfNotExists(def.stringValue, def.stringValue2); + + // Define standard macros, if not already defined. This style assumes using `#if __SOME_VAR` + // style, as in + // + // ``` + // #if __SLANG_COMPILER__ + // ``` + // + // This choice is made because slang outputs a warning on using a variable in an #if if not + // defined + // + // Of course this means using #ifndef/#ifdef/defined() is probably not appropraite with thes + // variables. + { + // Used to identify level of HLSL language compatibility + combinedPreprocessorDefinitions.addIfNotExists("__HLSL_VERSION", "2018"); + + // Indicates this is being compiled by the slang *compiler* + combinedPreprocessorDefinitions.addIfNotExists("__SLANG_COMPILER__", "1"); + + // Set macro depending on source type + switch (sourceLanguage) + { + case SourceLanguage::HLSL: + // Used to indicate compiled as HLSL language + combinedPreprocessorDefinitions.addIfNotExists("__HLSL__", "1"); + break; + case SourceLanguage::Slang: + // Used to indicate compiled as Slang language + combinedPreprocessorDefinitions.addIfNotExists("__SLANG__", "1"); + break; + default: + break; + } + + // If not set, define as 0. + combinedPreprocessorDefinitions.addIfNotExists("__HLSL__", "0"); + combinedPreprocessorDefinitions.addIfNotExists("__SLANG__", "0"); + } + + return combinedPreprocessorDefinitions; +} + +void TranslationUnitRequest::addSourceArtifact(IArtifact* sourceArtifact) +{ + SLANG_ASSERT(sourceArtifact); + m_sourceArtifacts.add(ComPtr<IArtifact>(sourceArtifact)); +} + + +void TranslationUnitRequest::addSource(IArtifact* sourceArtifact, SourceFile* sourceFile) +{ + SLANG_ASSERT(sourceArtifact && sourceFile); + // Must be in sync! + SLANG_ASSERT(m_sourceFiles.getCount() == m_sourceArtifacts.getCount()); + + addSourceArtifact(sourceArtifact); + _addSourceFile(sourceFile); +} + +void TranslationUnitRequest::addIncludedSourceFileIfNotExist(SourceFile* sourceFile) +{ + if (m_includedFileSet.contains(sourceFile)) + return; + + sourceFile->setIncludedFile(); + m_sourceFiles.add(sourceFile); + m_includedFileSet.add(sourceFile); +} + +PathInfo TranslationUnitRequest::_findSourcePathInfo(IArtifact* artifact) +{ + auto pathRep = findRepresentation<IPathArtifactRepresentation>(artifact); + + if (pathRep && pathRep->getPathType() == SLANG_PATH_TYPE_FILE) + { + // See if we have a unique identity set with the path + if (const auto uniqueIdentity = pathRep->getUniqueIdentity()) + { + return PathInfo::makeNormal(pathRep->getPath(), uniqueIdentity); + } + + // If we couldn't get a unique identity, just use the path + return PathInfo::makePath(pathRep->getPath()); + } + + // If there isn't a path, we can try with the name + const char* name = artifact->getName(); + if (name && name[0] != 0) + { + return PathInfo::makeFromString(name); + } + + return PathInfo::makeUnknown(); +} + +SlangResult TranslationUnitRequest::requireSourceFiles() +{ + SLANG_ASSERT(m_sourceFiles.getCount() <= m_sourceArtifacts.getCount()); + + if (m_sourceFiles.getCount() == m_sourceArtifacts.getCount()) + { + return SLANG_OK; + } + + auto sink = compileRequest->getSink(); + SourceManager* sourceManager = compileRequest->getSourceManager(); + + for (Index i = m_sourceFiles.getCount(); i < m_sourceArtifacts.getCount(); ++i) + { + IArtifact* artifact = m_sourceArtifacts[i]; + + const PathInfo pathInfo = _findSourcePathInfo(artifact); + + SourceFile* sourceFile = nullptr; + ComPtr<ISlangBlob> blob; + + // If we have a unique identity see if we have it already + if (pathInfo.hasUniqueIdentity()) + { + // See if this an already loaded source file + sourceFile = sourceManager->findSourceFileRecursively(pathInfo.uniqueIdentity); + // If we have a sourceFile see if it has a blob + if (sourceFile) + { + blob = sourceFile->getContentBlob(); + } + } + + // If we *don't* have a blob try and get a blob from the artifact + if (!blob) + { + const SlangResult res = artifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); + if (SLANG_FAILED(res)) + { + // Report couldn't load + sink->diagnose(SourceLoc(), Diagnostics::cannotOpenFile, pathInfo.getName()); + return res; + } + } + + // If we don't have a blob on the artifact we can now add the one we have + if (!findRepresentation<ISlangBlob>(artifact)) + { + artifact->addRepresentationUnknown(blob); + } + + // If we have a sourceFile check if it has contents, and set the blob if doesn't + if (sourceFile) + { + if (!sourceFile->getContentBlob()) + { + sourceFile->setContents(blob); + } + } + else + { + // Create a new source file, using the pathInfo and blob + sourceFile = sourceManager->createSourceFileWithBlob(pathInfo, blob); + } + + auto uniqueIdentity = pathInfo.getMostUniqueIdentity(); + if (uniqueIdentity.getLength()) + sourceManager->addSourceFileIfNotExist(uniqueIdentity, sourceFile); + + // Finally add the source file + _addSourceFile(sourceFile); + } + + return SLANG_OK; +} + +void TranslationUnitRequest::_addSourceFile(SourceFile* sourceFile) +{ + m_sourceFiles.add(sourceFile); + + getModule()->addFileDependency(sourceFile); + getModule()->getIncludedSourceFileMap().add(sourceFile, nullptr); +} + +List<SourceFile*> const& TranslationUnitRequest::getSourceFiles() +{ + return m_sourceFiles; +} + +} // namespace Slang diff --git a/source/slang/slang-translation-unit.h b/source/slang/slang-translation-unit.h new file mode 100644 index 000000000..ad19cb9ac --- /dev/null +++ b/source/slang/slang-translation-unit.h @@ -0,0 +1,117 @@ +// slang-translation-unit.h +#pragma once + +// +// This file provides the `TranslationUnitRequest` class, +// which is used to represent the inputs to front-end compilation +// that will yield a single `Module`. +// + +#include "../compiler-core/slang-artifact.h" +#include "../compiler-core/slang-source-loc.h" +#include "../core/slang-smart-pointer.h" +#include "slang-compiler-fwd.h" +#include "slang-entry-point.h" +#include "slang-module.h" +#include "slang-profile.h" + +namespace Slang +{ + +/// A request for the front-end to compile a translation unit. +class TranslationUnitRequest : public RefObject +{ +public: + TranslationUnitRequest(FrontEndCompileRequest* compileRequest); + TranslationUnitRequest(FrontEndCompileRequest* compileRequest, Module* m); + + // The parent compile request + FrontEndCompileRequest* compileRequest = nullptr; + + // The language in which the source file(s) + // are assumed to be written + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + + /// Makes any source artifact available as a SourceFile. + /// If successful any of the source artifacts will be represented by the same index + /// of sourceArtifacts + SlangResult requireSourceFiles(); + + /// Get the source files. + /// Since lazily evaluated requires calling requireSourceFiles to know it's in sync + /// with sourceArtifacts. + List<SourceFile*> const& getSourceFiles(); + + /// Get the source artifacts associated + const List<ComPtr<IArtifact>>& getSourceArtifacts() const { return m_sourceArtifacts; } + + /// Clear all of the source + void clearSource() + { + m_sourceArtifacts.clear(); + m_sourceFiles.clear(); + m_includedFileSet.clear(); + } + + /// Add a source artifact + void addSourceArtifact(IArtifact* sourceArtifact); + + /// Add both the artifact and the sourceFile. + void addSource(IArtifact* sourceArtifact, SourceFile* sourceFile); + + void addIncludedSourceFileIfNotExist(SourceFile* sourceFile); + + // The entry points associated with this translation unit + List<RefPtr<EntryPoint>> const& getEntryPoints() { return module->getEntryPoints(); } + + void _addEntryPoint(EntryPoint* entryPoint) { module->_addEntryPoint(entryPoint); } + + // Preprocessor definitions to use for this translation unit only + // (whereas the ones on `compileRequest` will be shared) + Dictionary<String, String> preprocessorDefinitions; + + /// The name that will be used for the module this translation unit produces. + Name* moduleName = nullptr; + + /// Result of compiling this translation unit (a module) + RefPtr<Module> module; + + bool isChecked = false; + + Module* getModule() { return module; } + ModuleDecl* getModuleDecl() { return module->getModuleDecl(); } + + Session* getSession(); + NamePool* getNamePool(); + SourceManager* getSourceManager(); + + Scope* getLanguageScope(); + + Dictionary<String, String> getCombinedPreprocessorDefinitions(); + + void setModuleName(Name* name) + { + moduleName = name; + if (module) + module->setName(name); + } + +protected: + void _addSourceFile(SourceFile* sourceFile); + /* Given an artifact, find a PathInfo. + If no PathInfo can be found will return an unknown PathInfo */ + PathInfo _findSourcePathInfo(IArtifact* artifact); + + List<ComPtr<IArtifact>> m_sourceArtifacts; + // The source file(s) that will be compiled to form this translation unit + // + // Usually, for HLSL or GLSL there will be only one file. + // NOTE! This member is generated lazily from m_sourceArtifacts + // it is *necessary* to call requireSourceFiles to ensure it's in sync. + List<SourceFile*> m_sourceFiles; + + // Track all the included source files added in m_sourceFiles + HashSet<SourceFile*> m_includedFileSet; +}; + +} // namespace Slang diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 839ba7938..a428e7928 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -48,102 +48,9 @@ // Used to print exception type names in internal-compiler-error messages #include <typeinfo> -extern Slang::String get_slang_cuda_prelude(); -extern Slang::String get_slang_cpp_prelude(); -extern Slang::String get_slang_hlsl_prelude(); - namespace Slang { - -/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOf)] = { - {0, 0, uint8_t(BaseType::Void)}, - {uint8_t(sizeof(bool)), 0, uint8_t(BaseType::Bool)}, - {uint8_t(sizeof(int8_t)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::Int8)}, - {uint8_t(sizeof(int16_t)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::Int16)}, - {uint8_t(sizeof(int32_t)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::Int)}, - {uint8_t(sizeof(int64_t)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::Int64)}, - {uint8_t(sizeof(uint8_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt8)}, - {uint8_t(sizeof(uint16_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt16)}, - {uint8_t(sizeof(uint32_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt)}, - {uint8_t(sizeof(uint64_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UInt64)}, - {uint8_t(sizeof(uint16_t)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Half)}, - {uint8_t(sizeof(float)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Float)}, - {uint8_t(sizeof(double)), BaseTypeInfo::Flag::FloatingPoint, uint8_t(BaseType::Double)}, - {uint8_t(sizeof(char)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::Char)}, - {uint8_t(sizeof(intptr_t)), - BaseTypeInfo::Flag::Signed | BaseTypeInfo::Flag::Integer, - uint8_t(BaseType::IntPtr)}, - {uint8_t(sizeof(uintptr_t)), BaseTypeInfo::Flag::Integer, uint8_t(BaseType::UIntPtr)}, -}; - -/* static */ bool BaseTypeInfo::check() -{ - for (Index i = 0; i < SLANG_COUNT_OF(s_info); ++i) - { - if (s_info[i].baseType != i) - { - SLANG_ASSERT(!"Inconsistency between the s_info table and BaseInfo"); - return false; - } - } - return true; -} - -/* static */ UnownedStringSlice BaseTypeInfo::asText(BaseType baseType) -{ - switch (baseType) - { - case BaseType::Void: - return UnownedStringSlice::fromLiteral("void"); - case BaseType::Bool: - return UnownedStringSlice::fromLiteral("bool"); - case BaseType::Int8: - return UnownedStringSlice::fromLiteral("int8_t"); - case BaseType::Int16: - return UnownedStringSlice::fromLiteral("int16_t"); - case BaseType::Int: - return UnownedStringSlice::fromLiteral("int"); - case BaseType::Int64: - return UnownedStringSlice::fromLiteral("int64_t"); - case BaseType::UInt8: - return UnownedStringSlice::fromLiteral("uint8_t"); - case BaseType::UInt16: - return UnownedStringSlice::fromLiteral("uint16_t"); - case BaseType::UInt: - return UnownedStringSlice::fromLiteral("uint"); - case BaseType::UInt64: - return UnownedStringSlice::fromLiteral("uint64_t"); - case BaseType::Half: - return UnownedStringSlice::fromLiteral("half"); - case BaseType::Float: - return UnownedStringSlice::fromLiteral("float"); - case BaseType::Double: - return UnownedStringSlice::fromLiteral("double"); - case BaseType::Char: - return UnownedStringSlice::fromLiteral("char"); - case BaseType::IntPtr: - return UnownedStringSlice::fromLiteral("intptr_t"); - case BaseType::UIntPtr: - return UnownedStringSlice::fromLiteral("uintptr_t"); - default: - { - SLANG_ASSERT(!"Unknown basic type"); - return UnownedStringSlice(); - } - } -} - const char* getBuildTagString() { if (UnownedStringSlice(SLANG_TAG_VERSION) == "0.0.0-unknown") @@ -158,1102 +65,6 @@ const char* getBuildTagString() return SLANG_TAG_VERSION; } - -void Session::init() -{ - SLANG_ASSERT(BaseTypeInfo::check()); - -#if SLANG_ENABLE_IR_BREAK_ALLOC - // Read environment variable for IR debugging - StringBuilder irBreakEnv; - if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable( - UnownedStringSlice("SLANG_DEBUG_IR_BREAK"), - irBreakEnv))) - { - String envValue = irBreakEnv.produceString(); - if (envValue.getLength()) - { - _slangIRAllocBreak = stringToInt(envValue); - _slangIRPrintStackAtBreak = true; - } - } -#endif - - _initCodeGenTransitionMap(); - - ::memset(m_downstreamCompilerLocators, 0, sizeof(m_downstreamCompilerLocators)); - DownstreamCompilerUtil::setDefaultLocators(m_downstreamCompilerLocators); - m_downstreamCompilerSet = new DownstreamCompilerSet; - - m_completionTokenName = getNamePool()->getName("#?"); - - m_sharedLibraryLoader = DefaultSharedLibraryLoader::getSingleton(); - - // Set up the command line options - initCommandOptions(m_commandOptions); - - // Set up shared AST builder - m_sharedASTBuilder = new SharedASTBuilder; - m_sharedASTBuilder->init(this); - - // And the global ASTBuilder - auto builtinAstBuilder = m_sharedASTBuilder->getInnerASTBuilder(); - globalAstBuilder = builtinAstBuilder; - - // Make sure our source manager is initialized - builtinSourceManager.initialize(nullptr, nullptr); - - // Built in linkage uses the built in builder - m_builtinLinkage = new Linkage(this, builtinAstBuilder, nullptr); - m_builtinLinkage->m_optionSet.set(CompilerOptionName::DebugInformation, DebugInfoLevel::None); - - // Because the `Session` retains the builtin `Linkage`, - // we need to make sure that the parent pointer inside - // `Linkage` doesn't create a retain cycle. - // - // This operation ensures that the parent pointer will - // just be a raw pointer, so that the builtin linkage - // doesn't keep the parent session alive. - // - m_builtinLinkage->_stopRetainingParentSession(); - - // Create scopes for various language builtins. - // - // TODO: load these on-demand to avoid parsing - // the core module code for languages the user won't use. - - baseLanguageScope = builtinAstBuilder->create<Scope>(); - - // Will stay in scope as long as ASTBuilder - baseModuleDecl = - populateBaseLanguageModule(m_builtinLinkage->getASTBuilder(), baseLanguageScope); - - coreLanguageScope = builtinAstBuilder->create<Scope>(); - coreLanguageScope->nextSibling = baseLanguageScope; - - hlslLanguageScope = builtinAstBuilder->create<Scope>(); - hlslLanguageScope->nextSibling = coreLanguageScope; - - slangLanguageScope = builtinAstBuilder->create<Scope>(); - slangLanguageScope->nextSibling = hlslLanguageScope; - - glslLanguageScope = builtinAstBuilder->create<Scope>(); - glslLanguageScope->nextSibling = slangLanguageScope; - - glslModuleName = getNameObj("glsl"); - - { - for (Index i = 0; i < Index(SourceLanguage::CountOf); ++i) - { - m_defaultDownstreamCompilers[i] = PassThroughMode::None; - } - m_defaultDownstreamCompilers[Index(SourceLanguage::C)] = PassThroughMode::GenericCCpp; - m_defaultDownstreamCompilers[Index(SourceLanguage::CPP)] = PassThroughMode::GenericCCpp; - m_defaultDownstreamCompilers[Index(SourceLanguage::CUDA)] = PassThroughMode::NVRTC; - } - - // Set up default prelude code for target languages that need a prelude - m_languagePreludes[Index(SourceLanguage::CUDA)] = get_slang_cuda_prelude(); - m_languagePreludes[Index(SourceLanguage::CPP)] = get_slang_cpp_prelude(); - m_languagePreludes[Index(SourceLanguage::HLSL)] = get_slang_hlsl_prelude(); - - if (!spirvCoreGrammarInfo) - spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::getEmbeddedVersion(); -} - -Module* Session::getBuiltinModule(slang::BuiltinModuleName name) -{ - auto info = getBuiltinModuleInfo(name); - auto builtinLinkage = getBuiltinLinkage(); - auto moduleNameObj = builtinLinkage->getNamePool()->getName(info.name); - RefPtr<Module> module; - if (builtinLinkage->mapNameToLoadedModules.tryGetValue(moduleNameObj, module)) - return module.get(); - return nullptr; -} - -void Session::_initCodeGenTransitionMap() -{ - // TODO(JS): Might want to do something about these in the future... - - // PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); - // SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); - - // Set up the default ways to do compilations between code gen targets - auto& map = m_codeGenTransitionMap; - - // TODO(JS): There currently isn't a 'downstream compiler' for direct spirv output. If we did - // it would presumably a transition from SlangIR to SPIRV. - - // For C and C++ we default to use the 'genericCCpp' compiler - { - const CodeGenTarget sources[] = {CodeGenTarget::CSource, CodeGenTarget::CPPSource}; - for (auto source : sources) - { - // We *don't* add a default for host callable, as we will determine what is suitable - // depending on what is available. We prefer LLVM if that's available. If it's not we - // can use generic C/C++ compiler - - map.addTransition( - source, - CodeGenTarget::ShaderSharedLibrary, - PassThroughMode::GenericCCpp); - map.addTransition( - source, - CodeGenTarget::HostSharedLibrary, - PassThroughMode::GenericCCpp); - map.addTransition(source, CodeGenTarget::HostExecutable, PassThroughMode::GenericCCpp); - map.addTransition(source, CodeGenTarget::ObjectCode, PassThroughMode::GenericCCpp); - } - } - - - // Add all the straightforward transitions - map.addTransition(CodeGenTarget::CUDASource, CodeGenTarget::PTX, PassThroughMode::NVRTC); - map.addTransition(CodeGenTarget::HLSL, CodeGenTarget::DXBytecode, PassThroughMode::Fxc); - map.addTransition(CodeGenTarget::HLSL, CodeGenTarget::DXIL, PassThroughMode::Dxc); - map.addTransition(CodeGenTarget::GLSL, CodeGenTarget::SPIRV, PassThroughMode::Glslang); - map.addTransition(CodeGenTarget::Metal, CodeGenTarget::MetalLib, PassThroughMode::MetalC); - map.addTransition(CodeGenTarget::WGSL, CodeGenTarget::WGSLSPIRV, PassThroughMode::Tint); - // To assembly - map.addTransition(CodeGenTarget::SPIRV, CodeGenTarget::SPIRVAssembly, PassThroughMode::Glslang); - // We use glslang to turn SPIR-V into SPIR-V assembly. - map.addTransition( - CodeGenTarget::WGSLSPIRV, - CodeGenTarget::WGSLSPIRVAssembly, - PassThroughMode::Glslang); - map.addTransition(CodeGenTarget::DXIL, CodeGenTarget::DXILAssembly, PassThroughMode::Dxc); - map.addTransition( - CodeGenTarget::DXBytecode, - CodeGenTarget::DXBytecodeAssembly, - PassThroughMode::Fxc); - map.addTransition( - CodeGenTarget::MetalLib, - CodeGenTarget::MetalLibAssembly, - PassThroughMode::MetalC); -} - -void Session::addBuiltins(char const* sourcePath, char const* source) -{ - auto sourceBlob = StringBlob::moveCreate(String(source)); - - // TODO(tfoley): Add ability to directly new builtins to the appropriate scope - Module* module = nullptr; - addBuiltinSource(coreLanguageScope, sourcePath, sourceBlob, module); - if (module) - coreModules.add(module); -} - -void Session::setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) -{ - // External API allows passing of nullptr to reset the loader - loader = loader ? loader : DefaultSharedLibraryLoader::getSingleton(); - - _setSharedLibraryLoader(loader); -} - -ISlangSharedLibraryLoader* Session::getSharedLibraryLoader() -{ - return (m_sharedLibraryLoader == DefaultSharedLibraryLoader::getSingleton()) - ? nullptr - : m_sharedLibraryLoader.get(); -} - -SlangResult Session::checkCompileTargetSupport(SlangCompileTarget inTarget) -{ - auto target = CodeGenTarget(inTarget); - - const PassThroughMode mode = getDownstreamCompilerRequiredForTarget(target); - return (mode != PassThroughMode::None) ? checkPassThroughSupport(SlangPassThrough(mode)) - : SLANG_OK; -} - -SlangResult Session::checkPassThroughSupport(SlangPassThrough inPassThrough) -{ - return checkExternalCompilerSupport(this, PassThroughMode(inPassThrough)); -} - -void Session::writeCoreModuleDoc(String config) -{ - ASTBuilder* astBuilder = getBuiltinLinkage()->getASTBuilder(); - SourceManager* sourceManager = getBuiltinSourceManager(); - - DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); - - List<String> docStrings; - - // For all the modules add their doc output to docStrings - for (Module* m : coreModules) - { - RefPtr<ASTMarkup> markup(new ASTMarkup); - ASTMarkupUtil::extract(m->getModuleDecl(), sourceManager, &sink, markup); - - DocMarkdownWriter writer(markup, astBuilder, &sink); - auto rootPage = writer.writeAll(config.getUnownedSlice()); - File::writeAllText("toc.html", writer.writeTOC()); - rootPage->writeToDisk(); - rootPage->writeSummary(toSlice("summary.txt")); - } - ComPtr<ISlangBlob> diagnosticBlob; - sink.getBlobIfNeeded(diagnosticBlob.writeRef()); - if (diagnosticBlob && diagnosticBlob->getBufferSize() != 0) - { - // Write the diagnostic blob to stdout. - fprintf(stderr, "%s", (const char*)diagnosticBlob->getBufferPointer()); - } -} - -const char* getBuiltinModuleNameStr(slang::BuiltinModuleName name) -{ - const char* result = nullptr; - switch (name) - { - case slang::BuiltinModuleName::Core: - result = "core"; - break; - case slang::BuiltinModuleName::GLSL: - result = "glsl"; - break; - default: - SLANG_UNEXPECTED("Unknown builtin module"); - } - return result; -} - -TypeCheckingCache* Session::getTypeCheckingCache() -{ - return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); -} - -Session::BuiltinModuleInfo Session::getBuiltinModuleInfo(slang::BuiltinModuleName name) -{ - Session::BuiltinModuleInfo result; - - result.name = getBuiltinModuleNameStr(name); - - switch (name) - { - case slang::BuiltinModuleName::Core: - result.languageScope = coreLanguageScope; - break; - case slang::BuiltinModuleName::GLSL: - result.name = "glsl"; - result.languageScope = glslLanguageScope; - break; - default: - SLANG_UNEXPECTED("Unknown builtin module"); - } - return result; -} - -SlangResult Session::compileCoreModule(slang::CompileCoreModuleFlags compileFlags) -{ - return compileBuiltinModule(slang::BuiltinModuleName::Core, compileFlags); -} - -void Session::getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName) -{ - switch (moduleName) - { - case slang::BuiltinModuleName::Core: - sb << (const char*)getCoreLibraryCode()->getBufferPointer() - << (const char*)getHLSLLibraryCode()->getBufferPointer() - << (const char*)getAutodiffLibraryCode()->getBufferPointer(); - break; - case slang::BuiltinModuleName::GLSL: - sb << (const char*)getGLSLLibraryCode()->getBufferPointer(); - break; - } -} - -SlangResult Session::compileBuiltinModule( - slang::BuiltinModuleName moduleName, - slang::CompileCoreModuleFlags compileFlags) -{ - SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); - -#ifdef _DEBUG - time_t beginTime = 0; - if (moduleName == slang::BuiltinModuleName::Core) - { - // Print a message in debug builds to notice the user that compiling the core module - // can take a while. - time(&beginTime); - fprintf(stderr, "Compiling core module on debug build, this can take a while.\n"); - } -#endif - BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(moduleName); - auto moduleNameObj = m_builtinLinkage->getNamePool()->getName(builtinModuleInfo.name); - if (m_builtinLinkage->mapNameToLoadedModules.tryGetValue(moduleNameObj)) - { - // Already have the builtin module loaded - return SLANG_FAIL; - } - - StringBuilder moduleSrcBuilder; - getBuiltinModuleSource(moduleSrcBuilder, moduleName); - - // TODO(JS): Could make this return a SlangResult as opposed to exception - auto moduleSrcBlob = StringBlob::moveCreate(moduleSrcBuilder.produceString()); - Module* compiledModule = nullptr; - addBuiltinSource( - builtinModuleInfo.languageScope, - builtinModuleInfo.name, - moduleSrcBlob, - compiledModule); - - if (moduleName == slang::BuiltinModuleName::Core) - { - // We need to retain this AST so that we can use it in other code - // (Note that the `Scope` type does not retain the AST it points to) - coreModules.add(compiledModule); - } - - if (compileFlags & slang::CompileCoreModuleFlag::WriteDocumentation) - { - // Load config file first. - String configText; - if (SLANG_FAILED(File::readAllText("config.txt", configText))) - { - fprintf( - stderr, - "Error writing documentation: config file not found on current working " - "directory.\n"); - } - else - { - writeCoreModuleDoc(configText); - } - } - - finalizeSharedASTBuilder(); - -#ifdef _DEBUG - if (moduleName == slang::BuiltinModuleName::Core) - { - time_t endTime; - time(&endTime); - fprintf(stderr, "Compiling core module took %.2f seconds.\n", difftime(endTime, beginTime)); - } -#endif - return SLANG_OK; -} - -SlangResult Session::loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) -{ - return loadBuiltinModule(slang::BuiltinModuleName::Core, coreModule, coreModuleSizeInBytes); -} - -SlangResult Session::loadBuiltinModule( - slang::BuiltinModuleName moduleName, - const void* moduleData, - size_t sizeInBytes) -{ - SLANG_PROFILE; - - - SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); - - BuiltinModuleInfo builtinModuleInfo = getBuiltinModuleInfo(moduleName); - auto nameObj = m_builtinLinkage->getNamePool()->getName(builtinModuleInfo.name); - if (m_builtinLinkage->mapNameToLoadedModules.containsKey(nameObj)) - { - // Already have a core module loaded - return SLANG_FAIL; - } - - // Make a file system to read it from - ComPtr<ISlangFileSystemExt> fileSystem; - SLANG_RETURN_ON_FAIL(loadArchiveFileSystem(moduleData, sizeInBytes, fileSystem)); - - // Let's try loading serialized modules and adding them - Module* module = nullptr; - SLANG_RETURN_ON_FAIL(_readBuiltinModule( - fileSystem, - builtinModuleInfo.languageScope, - builtinModuleInfo.name, - module)); - - if (moduleName == slang::BuiltinModuleName::Core) - { - // We need to retain this AST so that we can use it in other code - // (Note that the `Scope` type does not retain the AST it points to) - coreModules.add(module); - } - - finalizeSharedASTBuilder(); - return SLANG_OK; -} - -SlangResult Session::saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) -{ - return saveBuiltinModule(slang::BuiltinModuleName::Core, archiveType, outBlob); -} - -SlangResult Session::saveBuiltinModule( - slang::BuiltinModuleName moduleTag, - SlangArchiveType archiveType, - ISlangBlob** outBlob) -{ - // If no builtin modules have been loaded, then there is - // nothing to save, and we fail immediately. - // - if (m_builtinLinkage->mapNameToLoadedModules.getCount() == 0) - { - return SLANG_FAIL; - } - - // The module will need to be looked up by its name, and - // will also be serialized out to a path with a matching name. - // - BuiltinModuleInfo moduleInfo = getBuiltinModuleInfo(moduleTag); - const char* moduleName = moduleInfo.name; - - // If we cannot find a loaded module in the linkage with - // the appropriate name, then for some reason it hasn't - // been loaded, and we fail. - // - RefPtr<Module> module; - m_builtinLinkage->mapNameToLoadedModules.tryGetValue( - getNameObj(UnownedStringSlice(moduleName)), - module); - if (!module) - { - return SLANG_FAIL; - } - - // AST serialization needs access to an AST builder, so - // we establish a current builder for the duration of - // the serialization process. - // - SLANG_AST_BUILDER_RAII(m_builtinLinkage->getASTBuilder()); - - // The serialized module will be represented as a logical - // file in an archive, so we create a logical file system - // to represent that archive. - // - ComPtr<ISlangMutableFileSystem> fileSystem; - SLANG_RETURN_ON_FAIL(createArchiveFileSystem(archiveType, fileSystem)); - // - // The created file system must support the `IArchiveFileSystem` - // interface (since we created it with `createArchiveFileSystem`). - // - auto archiveFileSystem = as<IArchiveFileSystem>(fileSystem); - if (!archiveFileSystem) - { - return SLANG_FAIL; - } - - // The output file name that we'll write to in that file system - // is just the builtin module name with a `.slang-module` suffix. - // - StringBuilder moduleFileName; - moduleFileName << moduleName << ".slang-module"; - - // The module serialization step has some options that we need - // to configure appropriately. - // - SerialContainerUtil::WriteOptions options; - // - // We want builtin modules to be saved with their source location - // information. - // - // And in order to work with source locations, the serialization - // process will also need access to the source manager that - // can translate locations into their humane format. - // - options.sourceManagerToUseWhenSerializingSourceLocs = m_builtinLinkage->getSourceManager(); - - // At this point we can finally delegate down to the next level, - // which handles the serialization of a Slang module into a - // byte stream. - // - OwnedMemoryStream stream(FileAccess::Write); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(module, options, &stream)); - auto contents = stream.getContents(); - - // Once the stream that represents the module has been written, we can - // write it to a file in the logical file system. - // - // TODO(tfoley): why can't the file system let us open the file for output? - // - SLANG_RETURN_ON_FAIL(fileSystem->saveFile( - moduleFileName.getBuffer(), - contents.getBuffer(), - contents.getCount())); - - // And finally, we can ask the archive file system to serialize itself - // out as a blob of bytes, which yields the final serialized representation - // of the module. - // - SLANG_RETURN_ON_FAIL(archiveFileSystem->storeArchive( - // The `true` here indicates that the blob that gets created should own - // its content, independent from the file system object itself; otherwise - // the file system might return a blob that shares storage with itself. - true, - outBlob)); - - return SLANG_OK; -} - -SlangResult Session::_readBuiltinModule( - ISlangFileSystem* fileSystem, - Scope* scope, - String moduleName, - Module*& outModule) -{ - // Get the name of the module - StringBuilder moduleFilename; - moduleFilename << moduleName << ".slang-module"; - - // Load it - ComPtr<ISlangBlob> fileContents; - SLANG_RETURN_ON_FAIL(fileSystem->loadFile(moduleFilename.getBuffer(), fileContents.writeRef())); - - RIFF::RootChunk const* rootChunk = RIFF::RootChunk::getFromBlob(fileContents); - if (!rootChunk) - { - return SLANG_FAIL; - } - - Linkage* linkage = getBuiltinLinkage(); - SourceManager* sourceManager = getBuiltinSourceManager(); - NamePool* sessionNamePool = &namePool; - - auto moduleChunk = ModuleChunk::find(rootChunk); - if (!moduleChunk) - return SLANG_FAIL; - - SHA1::Digest moduleDigest = moduleChunk->getDigest(); - - auto irChunk = moduleChunk->findIR(); - if (!irChunk) - return SLANG_FAIL; - - auto astChunk = moduleChunk->findAST(); - if (!astChunk) - return SLANG_FAIL; - - // Source location information is stored as a distinct - // chunk from the IR and AST, so we need to search for - // that chunk and then set up the information for use - // in the IR and AST deserialization (if we find anything). - // - RefPtr<SerialSourceLocReader> sourceLocReader; - if (auto debugChunk = DebugChunk::find(moduleChunk)) - { - SLANG_RETURN_ON_FAIL( - readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); - } - - // At this point we create the `Module` object that will - // represent the builtin module we are reading, although - // it is still possible that deserialization will fail - // at one of the following steps. - // - auto astBuilder = linkage->getASTBuilder(); - RefPtr<Module> module(new Module(linkage, astBuilder)); - module->setName(moduleName); - module->setDigest(moduleDigest); - - - // Next, we set about deserializing the AST representation - // of the module. - // - auto moduleDecl = readSerializedModuleAST( - linkage, - astBuilder, - nullptr, // no sink - fileContents, - astChunk, - sourceLocReader, - SourceLoc()); - if (!moduleDecl) - { - return SLANG_FAIL; - } - moduleDecl->module = module; - module->setModuleDecl(moduleDecl); - - // After the AST module has been read in, we next look - // to deserialize the IR module. - // - RefPtr<IRModule> irModule; - SLANG_RETURN_ON_FAIL(readSerializedModuleIR(irChunk, this, sourceLocReader, irModule)); - - irModule->setName(module->getNameObj()); - module->setIRModule(irModule); - - // Put in the loaded module map - linkage->mapNameToLoadedModules.add(sessionNamePool->getName(moduleName), module); - - - // Add the resulting code to the appropriate scope - if (!scope->containerDecl) - { - // We are the first chunk of code to be loaded for this scope - scope->containerDecl = moduleDecl; - } - else - { - // We need to create a new scope to link into the whole thing - auto subScope = linkage->getASTBuilder()->create<Scope>(); - subScope->containerDecl = moduleDecl; - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - - outModule = module.get(); - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Session::queryInterface(SlangUUID const& uuid, void** outObject) -{ - if (uuid == Session::getTypeGuid()) - { - addReference(); - *outObject = static_cast<Session*>(this); - return SLANG_OK; - } - - if (uuid == ISlangUnknown::getTypeGuid() || uuid == IGlobalSession::getTypeGuid()) - { - addReference(); - *outObject = static_cast<slang::IGlobalSession*>(this); - return SLANG_OK; - } - - return SLANG_E_NO_INTERFACE; -} - -static size_t _getStructureSize(const uint8_t* src) -{ - size_t size = 0; - ::memcpy(&size, src, sizeof(size_t)); - return size; -} - -template<typename T> -static T makeFromSizeVersioned(const uint8_t* src) -{ - // The structure size must be size_t - SLANG_COMPILE_TIME_ASSERT(sizeof(((T*)src)->structureSize) == sizeof(size_t)); - - // The structureSize field *must* be the first element of T - // Ideally would use SLANG_COMPILE_TIME_ASSERT, but that doesn't work on gcc. - // Can't just assert, because determined to be a constant expression - { - auto offset = SLANG_OFFSET_OF(T, structureSize); - SLANG_ASSERT(offset == 0); - // Needed because offset is only 'used' by an assert - SLANG_UNUSED(offset); - } - - // The source size is held in the first element of T, and will be in the first bytes of src. - const size_t srcSize = _getStructureSize(src); - const size_t dstSize = sizeof(T); - - // If they are the same size, and appropriate alignment we can just cast and return - if (srcSize == dstSize && (size_t(src) & (alignof(T) - 1)) == 0) - { - return *(const T*)src; - } - - // Assumes T can default constructed sensibly - T dst; - - // It's structure size should be setup and should be dstSize - SLANG_ASSERT(dst.structureSize == dstSize); - - // The size to copy is the minimum on the two sizes - const auto copySize = std::min(srcSize, dstSize); - ::memcpy(&dst, src, copySize); - - // The final struct size is the destination size - dst.structureSize = dstSize; - - return dst; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Session::createSession(slang::SessionDesc const& inDesc, slang::ISession** outSession) -{ - RefPtr<ASTBuilder> astBuilder(new ASTBuilder(m_sharedASTBuilder, "Session::astBuilder")); - slang::SessionDesc desc = makeFromSizeVersioned<slang::SessionDesc>((uint8_t*)&inDesc); - - RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage()); - - if (desc.skipSPIRVValidation) - { - linkage->m_optionSet.set(CompilerOptionName::SkipSPIRVValidation, true); - } - - { - std::lock_guard<std::mutex> lock(m_typeCheckingCacheMutex); - if (m_typeCheckingCache) - linkage->m_typeCheckingCache = - new TypeCheckingCache(*static_cast<TypeCheckingCache*>(m_typeCheckingCache.get())); - } - - Int searchPathCount = desc.searchPathCount; - for (Int ii = 0; ii < searchPathCount; ++ii) - { - linkage->addSearchPath(desc.searchPaths[ii]); - } - - Int macroCount = desc.preprocessorMacroCount; - for (Int ii = 0; ii < macroCount; ++ii) - { - auto& macro = desc.preprocessorMacros[ii]; - linkage->addPreprocessorDefine(macro.name, macro.value); - } - - if (desc.fileSystem) - { - linkage->setFileSystem(desc.fileSystem); - } - - if (desc.structureSize >= offsetof(slang::SessionDesc, enableEffectAnnotations)) - { - linkage->m_optionSet.set( - CompilerOptionName::EnableEffectAnnotations, - desc.enableEffectAnnotations); - } - - linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); - - if (!linkage->m_optionSet.hasOption(CompilerOptionName::MatrixLayoutColumn) && - !linkage->m_optionSet.hasOption(CompilerOptionName::MatrixLayoutRow)) - linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode); - - { - const Int targetCount = desc.targetCount; - const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets); - for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr)) - { - const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr); - linkage->addTarget(targetDesc); - } - } - - // If any target requires debug info, then we will need to enable debug info when lowering to - // target-agnostic IR. The target-agnostic IR will only include debug info if the linkage IR - // options specify that it should, so make sure the linkage debug info level is greater than or - // equal to that of any target. - DebugInfoLevel linkageDebugInfoLevel = linkage->m_optionSet.getDebugInfoLevel(); - for (auto target : linkage->targets) - linkageDebugInfoLevel = - Math::Max(linkageDebugInfoLevel, target->getOptionSet().getDebugInfoLevel()); - linkage->m_optionSet.set(CompilerOptionName::DebugInformation, linkageDebugInfoLevel); - - // Add any referenced modules to the linkage - for (auto& option : linkage->m_optionSet.options) - { - if (option.key != CompilerOptionName::ReferenceModule) - continue; - for (auto& path : option.value) - { - DiagnosticSink sink; - ComPtr<IArtifact> artifact; - SlangResult result = createArtifactFromReferencedModule( - path.stringValue, - SourceLoc{}, - &sink, - artifact.writeRef()); - if (SLANG_FAILED(result)) - { - sink.diagnose(SourceLoc{}, Diagnostics::unableToReadFile, path.stringValue); - return result; - } - linkage->m_libModules.add(artifact); - } - } - - *outSession = asExternal(linkage.detach()); - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Session::createCompileRequest(slang::ICompileRequest** outCompileRequest) -{ - auto req = new EndToEndCompileRequest(this); - - // Give it a ref (for output) - req->addRef(); - // Check it is what we think it should be - SLANG_ASSERT(req->debugGetReferenceCount() == 1); - - *outCompileRequest = req; - return SLANG_OK; -} - -SLANG_NO_THROW SlangProfileID SLANG_MCALL Session::findProfile(char const* name) -{ - return SlangProfileID(Slang::Profile::lookUp(name).raw); -} - -SLANG_NO_THROW SlangCapabilityID SLANG_MCALL Session::findCapability(char const* name) -{ - return SlangCapabilityID(Slang::findCapabilityName(UnownedTerminatedStringSlice(name))); -} - -SLANG_NO_THROW void SLANG_MCALL -Session::setDownstreamCompilerPath(SlangPassThrough inPassThrough, char const* path) -{ - PassThroughMode passThrough = PassThroughMode(inPassThrough); - SLANG_ASSERT( - int(passThrough) > int(PassThroughMode::None) && - int(passThrough) < int(PassThroughMode::CountOf)); - - if (m_downstreamCompilerPaths[int(passThrough)] != path) - { - // Make access redetermine compiler - resetDownstreamCompiler(passThrough); - // Set the path - m_downstreamCompilerPaths[int(passThrough)] = path; - } -} - -SLANG_NO_THROW void SLANG_MCALL -Session::setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) -{ - PassThroughMode downstreamCompiler = PassThroughMode(inPassThrough); - SLANG_ASSERT( - int(downstreamCompiler) > int(PassThroughMode::None) && - int(downstreamCompiler) < int(PassThroughMode::CountOf)); - const SourceLanguage sourceLanguage = - getDefaultSourceLanguageForDownstreamCompiler(downstreamCompiler); - setLanguagePrelude(SlangSourceLanguage(sourceLanguage), prelude); -} - -SLANG_NO_THROW void SLANG_MCALL -Session::getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) -{ - PassThroughMode downstreamCompiler = PassThroughMode(inPassThrough); - SLANG_ASSERT( - int(downstreamCompiler) > int(PassThroughMode::None) && - int(downstreamCompiler) < int(PassThroughMode::CountOf)); - const SourceLanguage sourceLanguage = - getDefaultSourceLanguageForDownstreamCompiler(downstreamCompiler); - getLanguagePrelude(SlangSourceLanguage(sourceLanguage), outPrelude); -} - -SLANG_NO_THROW void SLANG_MCALL -Session::setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) -{ - SourceLanguage sourceLanguage = SourceLanguage(inSourceLanguage); - SLANG_ASSERT( - int(sourceLanguage) > int(SourceLanguage::Unknown) && - int(sourceLanguage) < int(SourceLanguage::CountOf)); - - SLANG_ASSERT(sourceLanguage != SourceLanguage::Unknown); - - if (sourceLanguage != SourceLanguage::Unknown) - { - m_languagePreludes[int(sourceLanguage)] = prelude; - } -} - -SLANG_NO_THROW void SLANG_MCALL -Session::getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) -{ - SourceLanguage sourceLanguage = SourceLanguage(inSourceLanguage); - - *outPrelude = nullptr; - if (sourceLanguage != SourceLanguage::Unknown) - { - SLANG_ASSERT( - int(sourceLanguage) > int(SourceLanguage::Unknown) && - int(sourceLanguage) < int(SourceLanguage::CountOf)); - *outPrelude = - Slang::StringUtil::createStringBlob(m_languagePreludes[int(sourceLanguage)]).detach(); - } -} - -SLANG_NO_THROW const char* SLANG_MCALL Session::getBuildTagString() -{ - return ::Slang::getBuildTagString(); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Session::setDefaultDownstreamCompiler( - SlangSourceLanguage sourceLanguage, - SlangPassThrough defaultCompiler) -{ - if (DownstreamCompilerInfo::canCompile(defaultCompiler, sourceLanguage)) - { - m_defaultDownstreamCompilers[int(sourceLanguage)] = PassThroughMode(defaultCompiler); - return SLANG_OK; - } - return SLANG_FAIL; -} - -SlangPassThrough SLANG_MCALL -Session::getDefaultDownstreamCompiler(SlangSourceLanguage inSourceLanguage) -{ - SLANG_ASSERT(inSourceLanguage >= 0 && inSourceLanguage < SLANG_SOURCE_LANGUAGE_COUNT_OF); - auto sourceLanguage = SourceLanguage(inSourceLanguage); - return SlangPassThrough(m_defaultDownstreamCompilers[int(sourceLanguage)]); -} - -void Session::setDownstreamCompilerForTransition( - SlangCompileTarget source, - SlangCompileTarget target, - SlangPassThrough compiler) -{ - if (compiler == SLANG_PASS_THROUGH_NONE) - { - // Removing the transition means a default can be used - m_codeGenTransitionMap.removeTransition(CodeGenTarget(source), CodeGenTarget(target)); - } - else - { - m_codeGenTransitionMap.addTransition( - CodeGenTarget(source), - CodeGenTarget(target), - PassThroughMode(compiler)); - } -} - -SlangPassThrough Session::getDownstreamCompilerForTransition( - SlangCompileTarget inSource, - SlangCompileTarget inTarget) -{ - const CodeGenTarget source = CodeGenTarget(inSource); - const CodeGenTarget target = CodeGenTarget(inTarget); - - if (m_codeGenTransitionMap.hasTransition(source, target)) - { - return (SlangPassThrough)m_codeGenTransitionMap.getTransition(source, target); - } - - const auto desc = ArtifactDescUtil::makeDescForCompileTarget(inTarget); - - // Special case host-callable - if ((desc.kind == ArtifactKind::HostCallable) && - (source == CodeGenTarget::CSource || source == CodeGenTarget::CPPSource)) - { - // We prefer LLVM if it's available - if (const auto llvm = getOrLoadDownstreamCompiler(PassThroughMode::LLVM, nullptr)) - { - return SLANG_PASS_THROUGH_LLVM; - } - } - - // Use the legacy 'sourceLanguage' default mechanism. - // This says nothing about the target type, so it is *assumed* the target type is possible - // If not it will fail when trying to compile to an unknown target - const SourceLanguage sourceLanguage = - (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget(inSource); - if (sourceLanguage != SourceLanguage::Unknown) - { - return getDefaultDownstreamCompiler(SlangSourceLanguage(sourceLanguage)); - } - - // Unknwon - return SLANG_PASS_THROUGH_NONE; -} - -IDownstreamCompiler* Session::getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target) -{ - PassThroughMode compilerType = (PassThroughMode)getDownstreamCompilerForTransition( - SlangCompileTarget(source), - SlangCompileTarget(target)); - return getOrLoadDownstreamCompiler(compilerType, nullptr); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Session::setSPIRVCoreGrammar(char const* jsonPath) -{ - if (!jsonPath) - { - spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::getEmbeddedVersion(); - SLANG_ASSERT(spirvCoreGrammarInfo); - } - else - { - SourceManager* sourceManager = getBuiltinSourceManager(); - SLANG_ASSERT(sourceManager); - DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); - - String contents; - const auto readRes = File::readAllText(jsonPath, contents); - if (SLANG_FAILED(readRes)) - { - sink.diagnose(SourceLoc{}, Diagnostics::unableToReadFile, jsonPath); - return readRes; - } - const auto pathInfo = PathInfo::makeFromString(jsonPath); - const auto sourceFile = sourceManager->createSourceFileWithString(pathInfo, contents); - const auto sourceView = sourceManager->createSourceView(sourceFile, nullptr, SourceLoc()); - spirvCoreGrammarInfo = SPIRVCoreGrammarInfo::loadFromJSON(*sourceView, sink); - } - return spirvCoreGrammarInfo ? SLANG_OK : SLANG_FAIL; -} - -struct ParsedCommandLineData : public ISlangUnknown, public ComObject -{ - SLANG_COM_OBJECT_IUNKNOWN_ALL - - ISlangUnknown* getInterface(const Slang::Guid& guid) - { - if (guid == ISlangUnknown::getTypeGuid()) - return this; - return nullptr; - } - List<SerializedOptionsData> options; - List<slang::TargetDesc> targets; -}; - -SLANG_NO_THROW SlangResult SLANG_MCALL Session::parseCommandLineArguments( - int argc, - const char* const* argv, - slang::SessionDesc* outDesc, - ISlangUnknown** outAllocation) -{ - if (outDesc->structureSize < sizeof(slang::SessionDesc)) - return SLANG_E_BUFFER_TOO_SMALL; - RefPtr<ParsedCommandLineData> outData = new ParsedCommandLineData(); - RefPtr<EndToEndCompileRequest> tempReq = new EndToEndCompileRequest(this); - tempReq->processCommandLineArguments(argv, argc); - outData->options.setCount(1 + tempReq->getLinkage()->targets.getCount()); - int optionDataIndex = 0; - SerializedOptionsData& optionData = outData->options[optionDataIndex]; - optionDataIndex++; - tempReq->getOptionSet().serialize(&optionData); - tempReq->m_optionSetForDefaultTarget.serialize(&optionData); - for (auto target : tempReq->getLinkage()->targets) - { - slang::TargetDesc tdesc; - SerializedOptionsData& targetOptionData = outData->options[optionDataIndex]; - optionDataIndex++; - tempReq->getTargetOptionSet(target).serialize(&targetOptionData); - tdesc.compilerOptionEntryCount = (uint32_t)targetOptionData.entries.getCount(); - tdesc.compilerOptionEntries = targetOptionData.entries.getBuffer(); - outData->targets.add(tdesc); - } - outDesc->compilerOptionEntryCount = (uint32_t)optionData.entries.getCount(); - outDesc->compilerOptionEntries = optionData.entries.getBuffer(); - outDesc->targetCount = outData->targets.getCount(); - outDesc->targets = outData->targets.getBuffer(); - *outAllocation = outData.get(); - outData->addRef(); - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Session::getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) -{ - ComPtr<slang::ISession> tempSession; - createSession(*sessionDesc, tempSession.writeRef()); - auto linkage = static_cast<Linkage*>(tempSession.get()); - DigestBuilder<SHA1> digestBuilder; - linkage->buildHash(digestBuilder, -1); - auto blob = digestBuilder.finalize().toBlob(); - *outBlob = blob.detach(); - return SLANG_OK; -} - Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) { auto entryPointProfile = entryPoint->getProfile(); @@ -1375,6885 +186,4 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) return effectiveProfile; } - -// - -Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage) - : m_session(session) - , m_retainedSession(session) - , m_sourceManager(&m_defaultSourceManager) - , m_astBuilder(astBuilder) - , m_cmdLineContext(new CommandLineContext()) - , m_stringSlicePool(StringSlicePool::Style::Default) -{ - namePool = session->getNamePool(); - - m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); - - setFileSystem(nullptr); - - // Copy of the built in linkages modules - if (builtinLinkage) - { - for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules) - mapNameToLoadedModules.add(nameToMod); - } - - m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr); -} - -SharedSemanticsContext* Linkage::getSemanticsForReflection() -{ - return m_semanticsForReflection.get(); -} - -ISlangUnknown* Linkage::getInterface(const Guid& guid) -{ - if (guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid()) - return asExternal(this); - - return nullptr; -} - -Linkage::~Linkage() -{ - // Upstream type checking cache. - if (m_typeCheckingCache) - { - auto globalSession = getSessionImpl(); - std::lock_guard<std::mutex> lock(globalSession->m_typeCheckingCacheMutex); - if (!globalSession->m_typeCheckingCache || - globalSession->getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount() < - getTypeCheckingCache()->resolvedOperatorOverloadCache.getCount()) - { - globalSession->m_typeCheckingCache = m_typeCheckingCache; - getTypeCheckingCache()->version++; - } - destroyTypeCheckingCache(); - } -} - -SearchDirectoryList& Linkage::getSearchDirectories() -{ - auto list = m_optionSet.getArray(CompilerOptionName::Include); - if (list.getCount() != searchDirectoryCache.searchDirectories.getCount()) - { - searchDirectoryCache.searchDirectories.clear(); - for (auto dir : list) - searchDirectoryCache.searchDirectories.add(SearchDirectory(dir.stringValue)); - } - return searchDirectoryCache; -} - -TypeCheckingCache* Linkage::getTypeCheckingCache() -{ - if (!m_typeCheckingCache) - { - m_typeCheckingCache = new TypeCheckingCache(); - } - return static_cast<TypeCheckingCache*>(m_typeCheckingCache.get()); -} - -void Linkage::destroyTypeCheckingCache() -{ - m_typeCheckingCache = nullptr; -} - -SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL Linkage::getGlobalSession() -{ - return asExternal(getSessionImpl()); -} - -void Linkage::addTarget(slang::TargetDesc const& desc) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto targetIndex = addTarget(CodeGenTarget(desc.format)); - auto target = targets[targetIndex]; - - auto& optionSet = target->getOptionSet(); - optionSet.inheritFrom(m_optionSet); - - optionSet.set(CompilerOptionName::FloatingPointMode, FloatingPointMode(desc.floatingPointMode)); - optionSet.addTargetFlags(desc.flags); - optionSet.setProfile(Profile(desc.profile)); - optionSet.set(CompilerOptionName::LineDirectiveMode, LineDirectiveMode(desc.lineDirectiveMode)); - optionSet.set(CompilerOptionName::GLSLForceScalarLayout, desc.forceGLSLScalarBufferLayout); - - CompilerOptionSet targetOptions; - targetOptions.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); - optionSet.overrideWith(targetOptions); -} - -#if 0 - SLANG_NO_THROW SlangInt SLANG_MCALL Linkage::getTargetCount() - { - return targets.getCount(); - } - - SLANG_NO_THROW slang::ITarget* SLANG_MCALL Linkage::getTargetByIndex(SlangInt index) - { - if (index < 0) return nullptr; - if (index >= targets.getCount()) return nullptr; - return asExternal(targets[index]); - } -#endif - -static void outputExceptionDiagnostic( - const AbortCompilationException& exception, - DiagnosticSink& sink, - slang::IBlob** outDiagnostics) -{ - sink.diagnoseRaw(Severity::Error, exception.Message.getUnownedSlice()); - sink.getBlobIfNeeded(outDiagnostics); -} - -static void outputExceptionDiagnostic( - const Exception& exception, - DiagnosticSink& sink, - slang::IBlob** outDiagnostics) -{ - try - { - sink.diagnoseRaw(Severity::Internal, exception.Message.getUnownedSlice()); - } - catch (const AbortCompilationException&) - { - // Catch and ignore the AbortCompilationException that diagnoseRaw throws - // for Internal severity to prevent exception leak from loadModule - } - sink.getBlobIfNeeded(outDiagnostics); -} - -static void outputExceptionDiagnostic(DiagnosticSink& sink, slang::IBlob** outDiagnostics) -{ - try - { - sink.diagnoseRaw(Severity::Fatal, "An unknown exception occurred"); - } - catch (const AbortCompilationException&) - { - // Catch and ignore the AbortCompilationException that diagnoseRaw throws - // for Fatal severity to prevent exception leak from loadModule - } - sink.getBlobIfNeeded(outDiagnostics); -} - -SLANG_NO_THROW slang::IModule* SLANG_MCALL -Linkage::loadModule(const char* moduleName, slang::IBlob** outDiagnostics) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - if (isInLanguageServer()) - { - sink.setFlags(DiagnosticSink::Flag::HumaneLoc | DiagnosticSink::Flag::LanguageServer); - } - - try - { - auto name = getNamePool()->getName(moduleName); - - auto module = findOrImportModule(name, SourceLoc(), &sink); - sink.getBlobIfNeeded(outDiagnostics); - - return asExternal(module); - } - catch (const AbortCompilationException& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return nullptr; - } - catch (const Exception& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return nullptr; - } - catch (...) - { - outputExceptionDiagnostic(sink, outDiagnostics); - return nullptr; - } -} - -slang::IModule* Linkage::loadModuleFromBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - ModuleBlobType blobType, - slang::IBlob** outDiagnostics) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - if (isInLanguageServer()) - { - sink.setFlags(DiagnosticSink::Flag::HumaneLoc | DiagnosticSink::Flag::LanguageServer); - } - - - try - { - auto getDigestStr = [](auto x) - { - DigestBuilder<SHA1> digestBuilder; - digestBuilder.append(x); - return digestBuilder.finalize().toString(); - }; - - String moduleNameStr = moduleName; - if (!moduleName) - moduleNameStr = getDigestStr(source); - - auto name = getNamePool()->getName(moduleNameStr); - RefPtr<LoadedModule> loadedModule; - if (mapNameToLoadedModules.tryGetValue(name, loadedModule)) - { - return loadedModule; - } - String pathStr = path; - if (pathStr.getLength() == 0) - { - // If path is empty, use a digest from source as path. - pathStr = getDigestStr(source); - } - auto pathInfo = PathInfo::makeFromString(pathStr); - if (File::exists(pathStr)) - { - String cannonicalPath; - if (SLANG_SUCCEEDED(Path::getCanonical(pathStr, cannonicalPath))) - { - pathInfo = PathInfo::makeNormal(pathStr, cannonicalPath); - } - } - RefPtr<Module> module = - loadModuleImpl(name, pathInfo, source, SourceLoc(), &sink, nullptr, blobType); - sink.getBlobIfNeeded(outDiagnostics); - return asExternal(module.get()); - } - catch (const AbortCompilationException& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return nullptr; - } - catch (const Exception& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return nullptr; - } - catch (...) - { - outputExceptionDiagnostic(sink, outDiagnostics); - return nullptr; - } -} - -SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics) -{ - return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::Source, outDiagnostics); -} - -SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSourceString( - const char* moduleName, - const char* path, - const char* source, - slang::IBlob** outDiagnostics) -{ - auto sourceBlob = StringBlob::create(UnownedStringSlice(source)); - return loadModuleFromSource(moduleName, path, sourceBlob.get(), outDiagnostics); -} - -SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromIRBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics) -{ - return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::IR, outDiagnostics); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::loadModuleInfoFromIRBlob( - slang::IBlob* source, - SlangInt& outModuleVersion, - const char*& outModuleCompilerVersion, - const char*& outModuleName) -{ - // We start by reading the content of the file as - // an in-memory RIFF container. - // - auto rootChunk = RIFF::RootChunk::getFromBlob(source); - if (!rootChunk) - { - return SLANG_FAIL; - } - - auto moduleChunk = ModuleChunk::find(rootChunk); - if (!moduleChunk) - { - return SLANG_FAIL; - } - - auto irChunk = moduleChunk->findIR(); - if (!irChunk) - { - return SLANG_FAIL; - } - - RefPtr<IRModule> irModule; - String compilerVersion; - UInt version; - String name; - SLANG_RETURN_ON_FAIL(readSerializedModuleInfo(irChunk, compilerVersion, version, name)); - const auto compilerVersionSlice = m_stringSlicePool.addAndGetSlice(compilerVersion); - const auto nameSlice = m_stringSlicePool.addAndGetSlice(name); - outModuleCompilerVersion = compilerVersionSlice.begin(); - outModuleName = nameSlice.begin(); - outModuleVersion = SlangInt(version); - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType( - slang::IComponentType* const* componentTypes, - SlangInt componentTypeCount, - slang::IComponentType** outCompositeComponentType, - ISlangBlob** outDiagnostics) -{ - if (outCompositeComponentType == nullptr) - return SLANG_E_INVALID_ARG; - - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - // Attempting to create a "composite" of just one component type should - // just return the component type itself, to avoid redundant work. - // - if (componentTypeCount == 1) - { - auto componentType = componentTypes[0]; - componentType->addRef(); - *outCompositeComponentType = componentType; - return SLANG_OK; - } - - DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - List<RefPtr<ComponentType>> childComponents; - for (Int cc = 0; cc < componentTypeCount; ++cc) - { - childComponents.add(asInternal(componentTypes[cc])); - } - - RefPtr<ComponentType> composite = CompositeComponentType::create(this, childComponents); - - sink.getBlobIfNeeded(outDiagnostics); - - *outCompositeComponentType = asExternal(composite.detach()); - return SLANG_OK; -} - -SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( - slang::TypeReflection* inUnspecializedType, - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - ISlangBlob** outDiagnostics) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto unspecializedType = asInternal(inUnspecializedType); - - List<Type*> typeArgs; - - for (Int ii = 0; ii < specializationArgCount; ++ii) - { - auto& arg = specializationArgs[ii]; - if (arg.kind != slang::SpecializationArg::Kind::Type) - return nullptr; - - typeArgs.add(asInternal(arg.type)); - } - - DiagnosticSink sink(getSourceManager(), Lexer::sourceLocationLexer); - auto specializedType = - specializeType(unspecializedType, typeArgs.getCount(), typeArgs.getBuffer(), &sink); - sink.getBlobIfNeeded(outDiagnostics); - - return asExternal(specializedType); -} - -DeclRef<GenericDecl> getGenericParentDeclRef( - ASTBuilder* astBuilder, - SemanticsVisitor* visitor, - DeclRef<Decl> declRef) -{ - // Create substituted parent decl ref. - auto decl = declRef.getDecl(); - - while (decl && !as<GenericDecl>(decl)) - { - decl = decl->parentDecl; - } - - if (!decl) - { - // No generic parent - return DeclRef<GenericDecl>(); - } - - auto genericDecl = as<GenericDecl>(decl); - auto genericDeclRef = - createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)) - .as<GenericDecl>(); - return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef) - .as<GenericDecl>(); -} - -bool Linkage::isSpecialized(DeclRef<Decl> declRef) -{ - // For now, we only support two 'states': fully applied or not at all. - // If we add support for partial specialization, we will need to update this logic. - // - // If it's not specialized, then declRef will be the one with default substitutions. - // - SemanticsVisitor visitor(getSemanticsForReflection()); - - auto decl = declRef.getDecl(); - while (decl && !as<GenericDecl>(decl)) - { - decl = decl->parentDecl; - } - - if (!decl) - return true; // no generics => always specialized - - auto defaultArgs = getDefaultSubstitutionArgs(getASTBuilder(), &visitor, as<GenericDecl>(decl)); - auto currentArgs = - SubstitutionSet(declRef).findGenericAppDeclRef(as<GenericDecl>(decl))->getArgs(); - - if (defaultArgs.getCount() != currentArgs.getCount()) // should really never happen. - return true; - - for (Index i = 0; i < defaultArgs.getCount(); ++i) - { - if (defaultArgs[i] != currentArgs[i]) - return true; - } - - return false; -} - -bool isFuncGeneric(DeclRef<Decl> declRef) -{ - if (auto funcDecl = as<FuncDecl>(declRef.getDecl())) - { - if (funcDecl->parentDecl && as<GenericDecl>(funcDecl->parentDecl)) - { - return true; - } - } - - return false; -} - -DeclRef<Decl> Linkage::specializeWithArgTypes( - Expr* funcExpr, - List<Type*> argTypes, - DiagnosticSink* sink) -{ - SemanticsVisitor visitor(getSemanticsForReflection()); - SemanticsVisitor::ExprLocalScope scope; - visitor = visitor.withSink(sink).withExprLocalScope(&scope); - - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - if (auto declRefFuncExpr = as<DeclRefExpr>(funcExpr)) - { - if (isFuncGeneric(declRefFuncExpr->declRef) && !isSpecialized(declRefFuncExpr->declRef)) - { - if (auto genericDeclRef = getGenericParentDeclRef( - getCurrentASTBuilder(), - &visitor, - declRefFuncExpr->declRef)) - { - auto genericDeclRefExpr = getCurrentASTBuilder()->create<DeclRefExpr>(); - genericDeclRefExpr->declRef = genericDeclRef; - funcExpr = genericDeclRefExpr; - } - } - } - - List<Expr*> argExprs; - for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa) - { - auto argType = argTypes[aa]; - - // Create an 'empty' expr with the given type. Ideally, the expression itself should not - // matter only its checked type. - // - auto argExpr = getCurrentASTBuilder()->create<VarExpr>(); - argExpr->type = argType; - argExpr->type.isLeftValue = true; - argExprs.add(argExpr); - } - - // Construct invoke expr. - auto invokeExpr = getCurrentASTBuilder()->create<InvokeExpr>(); - invokeExpr->functionExpr = funcExpr; - invokeExpr->arguments = argExprs; - - auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); - - return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef; -} - - -DeclRef<Decl> Linkage::specializeGeneric( - DeclRef<Decl> declRef, - List<Expr*> argExprs, - DiagnosticSink* sink) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - SLANG_ASSERT(declRef); - - SemanticsVisitor visitor(getSemanticsForReflection()); - visitor = visitor.withSink(sink); - - auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef); - - DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>(); - declRefExpr->declRef = genericDeclRef; - - GenericAppExpr* genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); - genericAppExpr->functionExpr = declRefExpr; - genericAppExpr->arguments = argExprs; - - auto specializedDeclRef = - as<DeclRefExpr>(visitor.checkGenericAppWithCheckedArgs(genericAppExpr))->declRef; - - return specializedDeclRef; -} - -SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout( - slang::TypeReflection* inType, - SlangInt targetIndex, - slang::LayoutRules rules, - ISlangBlob** outDiagnostics) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto type = asInternal(inType); - - if (targetIndex < 0 || targetIndex >= targets.getCount()) - return nullptr; - - auto target = targets[targetIndex]; - - // TODO: We need a way to pass through the layout rules - // that the user requested (e.g., constant buffers vs. - // structured buffer rules). Right now the API only - // exposes a single case, so this isn't a big deal. - // - SLANG_UNUSED(rules); - - auto typeLayout = target->getTypeLayout(type, rules); - - // TODO: We currently don't have a path for capturing - // errors that occur during layout (e.g., types that - // are invalid because of target-specific layout constraints). - // - SLANG_UNUSED(outDiagnostics); - - return asExternal(typeLayout); -} - -SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getContainerType( - slang::TypeReflection* inType, - slang::ContainerType containerType, - ISlangBlob** outDiagnostics) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto type = asInternal(inType); - - Type* containerTypeReflection = nullptr; - ContainerTypeKey key = {inType, containerType}; - if (!m_containerTypes.tryGetValue(key, containerTypeReflection)) - { - switch (containerType) - { - case slang::ContainerType::ConstantBuffer: - { - SemanticsVisitor visitor(getSemanticsForReflection()); - auto layoutType = getASTBuilder()->getDefaultLayoutType(); - Type* cbType = visitor.getConstantBufferType(type, layoutType); - containerTypeReflection = cbType; - } - break; - case slang::ContainerType::ParameterBlock: - { - ParameterBlockType* pbType = getASTBuilder()->getParameterBlockType(type); - containerTypeReflection = pbType; - } - break; - case slang::ContainerType::StructuredBuffer: - { - HLSLStructuredBufferType* sbType = getASTBuilder()->getStructuredBufferType(type); - containerTypeReflection = sbType; - } - break; - case slang::ContainerType::UnsizedArray: - { - ArrayExpressionType* arrType = getASTBuilder()->getArrayType(type, nullptr); - containerTypeReflection = arrType; - } - break; - default: - containerTypeReflection = type; - break; - } - - m_containerTypes.add(key, containerTypeReflection); - } - - SLANG_UNUSED(outDiagnostics); - - return asExternal(containerTypeReflection); -} - -SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::getDynamicType() -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - return asExternal(getASTBuilder()->getSharedASTBuilder()->getDynamicType()); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Linkage::getTypeRTTIMangledName(slang::TypeReflection* type, ISlangBlob** outNameBlob) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto internalType = asInternal(type); - if (auto declRefType = as<DeclRefType>(internalType)) - { - auto name = getMangledName(m_astBuilder, declRefType->getDeclRef()); - Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); - *outNameBlob = blob.detach(); - return SLANG_OK; - } - return SLANG_FAIL; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessMangledName( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - ISlangBlob** outNameBlob) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto subType = asInternal(type); - auto supType = asInternal(interfaceType); - auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); - Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(name); - *outNameBlob = blob.detach(); - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequentialID( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - uint32_t* outId) -{ - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - auto subType = asInternal(type); - auto supType = asInternal(interfaceType); - - if (!subType || !supType) - return SLANG_FAIL; - - auto name = getMangledNameForConformanceWitness(m_astBuilder, subType, supType); - auto interfaceName = getMangledTypeName(m_astBuilder, supType); - uint32_t resultIndex = 0; - if (mapMangledNameToRTTIObjectIndex.tryGetValue(name, resultIndex)) - { - if (outId) - *outId = resultIndex; - return SLANG_OK; - } - auto idAllocator = mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(interfaceName); - if (!idAllocator) - { - mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; - idAllocator = mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(interfaceName); - } - resultIndex = (*idAllocator); - ++(*idAllocator); - mapMangledNameToRTTIObjectIndex[name] = resultIndex; - if (outId) - *outId = resultIndex; - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getDynamicObjectRTTIBytes( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - uint32_t* outBuffer, - uint32_t bufferSize) -{ - // Slang RTTI header format: - // byte 0-7: pointer to RTTI struct describing the type. (not used for now, set to 1 for valid - // types, and 0 to represent null). - // byte 8-11: 32-bit sequential ID of the type conformance witness. - // byte 12-15: unused. - - if (bufferSize < 16) - return SLANG_E_BUFFER_TOO_SMALL; - - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - SLANG_RETURN_ON_FAIL(getTypeConformanceWitnessSequentialID(type, interfaceType, outBuffer + 2)); - - // Make the RTTI part non zero. - outBuffer[0] = 1; - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - slang::ITypeConformance** outConformanceComponentType, - SlangInt conformanceIdOverride, - ISlangBlob** outDiagnostics) -{ - if (outConformanceComponentType == nullptr) - return SLANG_E_INVALID_ARG; - - SLANG_AST_BUILDER_RAII(getASTBuilder()); - - RefPtr<TypeConformance> result; - DiagnosticSink sink; - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - try - { - SemanticsVisitor visitor(getSemanticsForReflection()); - visitor = visitor.withSink(&sink); - - auto witness = visitor.isSubtype( - (Slang::Type*)type, - (Slang::Type*)interfaceType, - IsSubTypeOptions::None); - if (auto subtypeWitness = as<SubtypeWitness>(witness)) - { - result = new TypeConformance(this, subtypeWitness, conformanceIdOverride, &sink); - } - } - catch (...) - { - } - sink.getBlobIfNeeded(outDiagnostics); - bool success = (result != nullptr); - *outConformanceComponentType = result.detach(); - return success ? SLANG_OK : SLANG_FAIL; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -Linkage::createCompileRequest(SlangCompileRequest** outCompileRequest) -{ - auto compileRequest = new EndToEndCompileRequest(this); - compileRequest->addRef(); - *outCompileRequest = asExternal(compileRequest); - return SLANG_OK; -} - -SLANG_NO_THROW SlangInt SLANG_MCALL Linkage::getLoadedModuleCount() -{ - return loadedModulesList.getCount(); -} - -SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::getLoadedModule(SlangInt index) -{ - if (index >= 0 && index < loadedModulesList.getCount()) - return loadedModulesList[index].get(); - return nullptr; -} - -void Linkage::buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex) -{ - // Add the Slang compiler version to the hash - auto version = String(getBuildTagString()); - builder.append(version); - - // Add compiler options, including search path, preprocessor includes, etc. - m_optionSet.buildHash(builder); - - auto addTargetDigest = [&](TargetRequest* targetReq) - { - targetReq->getOptionSet().buildHash(builder); - - const PassThroughMode passThroughMode = - getDownstreamCompilerRequiredForTarget(targetReq->getTarget()); - const SourceLanguage sourceLanguage = - getDefaultSourceLanguageForDownstreamCompiler(passThroughMode); - - // Add prelude for the given downstream compiler. - ComPtr<ISlangBlob> prelude; - getGlobalSession()->getLanguagePrelude( - (SlangSourceLanguage)sourceLanguage, - prelude.writeRef()); - if (prelude) - { - builder.append(prelude); - } - - // TODO: Downstream compilers (specifically dxc) can currently #include additional - // dependencies. This is currently the case for NVAPI headers included in the prelude. These - // dependencies are currently not picked up by the shader cache which is a significant - // issue. This can only be fixed by running the preprocessor in the slang compiler so dxc - // (or any other downstream compiler for that matter) isn't resolving any includes - // implicitly. - - // Add the downstream compiler version (if it exists) to the hash - auto downstreamCompiler = - getSessionImpl()->getOrLoadDownstreamCompiler(passThroughMode, nullptr); - if (downstreamCompiler) - { - ComPtr<ISlangBlob> versionString; - if (SLANG_SUCCEEDED(downstreamCompiler->getVersionString(versionString.writeRef()))) - { - builder.append(versionString); - } - } - }; - - // Add the target specified by targetIndex - if (targetIndex == -1) - { - // -1 means all targets. - for (auto targetReq : targets) - { - addTargetDigest(targetReq); - } - } - else - { - auto targetReq = targets[targetIndex]; - addTargetDigest(targetReq); - } -} - -SlangResult Linkage::addSearchPath(char const* path) -{ - m_optionSet.add(CompilerOptionName::Include, String(path)); - return SLANG_OK; -} - -SlangResult Linkage::addPreprocessorDefine(char const* name, char const* value) -{ - CompilerOptionValue val; - val.kind = CompilerOptionValueKind::String; - val.stringValue = name; - val.stringValue2 = value; - m_optionSet.add(CompilerOptionName::MacroDefine, val); - return SLANG_OK; -} - -SlangResult Linkage::setMatrixLayoutMode(SlangMatrixLayoutMode mode) -{ - m_optionSet.setMatrixLayoutMode((MatrixLayoutMode)mode); - return SLANG_OK; -} - -// -// TargetRequest -// - -TargetRequest::TargetRequest(Linkage* linkage, CodeGenTarget format) - : linkage(linkage) -{ - optionSet = linkage->m_optionSet; - optionSet.add(CompilerOptionName::Target, format); -} - -TargetRequest::TargetRequest(const TargetRequest& other) - : RefObject(), linkage(other.linkage), optionSet(other.optionSet) -{ -} - - -Session* TargetRequest::getSession() -{ - return linkage->getSessionImpl(); -} - -HLSLToVulkanLayoutOptions* TargetRequest::getHLSLToVulkanLayoutOptions() -{ - if (!hlslToVulkanOptions) - { - hlslToVulkanOptions = new HLSLToVulkanLayoutOptions(); - hlslToVulkanOptions->loadFromOptionSet(optionSet); - } - return hlslToVulkanOptions.get(); -} - -void TargetRequest::setTargetCaps(CapabilitySet capSet) -{ - cookedCapabilities = capSet; -} - -CapabilitySet TargetRequest::getTargetCaps() -{ - if (!cookedCapabilities.isEmpty()) - return cookedCapabilities; - - // The full `CapabilitySet` for the target will be computed - // from the combination of the code generation format, and - // the profile. - // - // Note: the preofile might have been set in a way that is - // inconsistent with the output code format of SPIR-V, but - // a profile of Direct3D Shader Model 5.1. In those cases, - // the format should always override the implications in - // the profile. - // - // TODO: This logic isn't currently taking int account - // the information in the profile, because the current - // `CapabilityAtom`s that we support don't include any - // of the details there (e.g., the shader model versions). - // - // Eventually, we'd want to have a rich set of capability - // atoms, so that most of the information about what operations - // are available where can be directly encoded on the declarations. - - List<CapabilityName> atoms; - - // If the user specified a explicit profile, we should pull - // a corresponding atom representing the target version from the profile. - CapabilitySet profileCaps = optionSet.getProfile().getCapabilityName(); - - bool isGLSLTarget = false; - switch (getTarget()) - { - case CodeGenTarget::GLSL: - isGLSLTarget = true; - atoms.add(CapabilityName::glsl); - break; - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: - if (getOptionSet().shouldEmitSPIRVDirectly()) - { - // Default to SPIRV 1.5 if the user has not specified a target version. - bool hasTargetVersionAtom = false; - if (!profileCaps.isEmpty()) - { - profileCaps.join(CapabilitySet(CapabilityName::spirv_1_0)); - for (auto profileCapAtomSet : profileCaps.getAtomSets()) - { - for (auto atom : profileCapAtomSet) - { - if (isTargetVersionAtom(asAtom(atom))) - { - atoms.add((CapabilityName)atom); - hasTargetVersionAtom = true; - } - } - } - } - if (!hasTargetVersionAtom) - { - atoms.add(CapabilityName::spirv_1_5); - } - // If the user specified any SPIR-V extensions in the profile, - // pull them in. - for (auto profileCapAtomSet : profileCaps.getAtomSets()) - { - for (auto atom : profileCapAtomSet) - { - if (isSpirvExtensionAtom(asAtom(atom))) - { - atoms.add((CapabilityName)atom); - hasTargetVersionAtom = true; - } - } - } - } - else - { - isGLSLTarget = true; - atoms.add(CapabilityName::glsl); - profileCaps.addSpirvVersionFromOtherAsGlslSpirvVersion(profileCaps); - } - break; - - case CodeGenTarget::HLSL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - atoms.add(CapabilityName::hlsl); - break; - - case CodeGenTarget::CSource: - atoms.add(CapabilityName::c); - break; - - case CodeGenTarget::CPPSource: - case CodeGenTarget::PyTorchCppBinding: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostSharedLibrary: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::ShaderHostCallable: - atoms.add(CapabilityName::cpp); - break; - - case CodeGenTarget::CUDASource: - case CodeGenTarget::PTX: - atoms.add(CapabilityName::cuda); - break; - - case CodeGenTarget::Metal: - case CodeGenTarget::MetalLib: - case CodeGenTarget::MetalLibAssembly: - atoms.add(CapabilityName::metal); - break; - - case CodeGenTarget::WGSLSPIRV: - case CodeGenTarget::WGSLSPIRVAssembly: - case CodeGenTarget::WGSL: - atoms.add(CapabilityName::wgsl); - break; - - default: - break; - } - - CapabilitySet targetCap = CapabilitySet(atoms); - - if (profileCaps.atLeastOneSetImpliedInOther(targetCap) == - CapabilitySet::ImpliesReturnFlags::Implied) - targetCap.join(profileCaps); - - for (auto atomVal : optionSet.getArray(CompilerOptionName::Capability)) - { - CapabilitySet toAdd; - switch (atomVal.kind) - { - case CompilerOptionValueKind::Int: - toAdd = CapabilitySet(CapabilityName(atomVal.intValue)); - break; - case CompilerOptionValueKind::String: - toAdd = CapabilitySet(findCapabilityName(atomVal.stringValue.getUnownedSlice())); - break; - } - - if (isGLSLTarget) - targetCap.addSpirvVersionFromOtherAsGlslSpirvVersion(toAdd); - - if (!targetCap.isIncompatibleWith(toAdd)) - targetCap.join(toAdd); - } - - cookedCapabilities = targetCap; - - SLANG_ASSERT(!cookedCapabilities.isInvalid()); - - return cookedCapabilities; -} - - -TypeLayout* TargetRequest::getTypeLayout(Type* type, slang::LayoutRules rules) -{ - SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); - - // TODO: We are not passing in a `ProgramLayout` here, although one - // is nominally required to establish the global ordering of - // generic type parameters, which might be referenced from field types. - // - // The solution here is to make sure that the reflection data for - // uses of global generic/existential types does *not* include any - // kind of index in that global ordering, and just refers to the - // parameter instead (leaving the user to figure out how that - // maps to the ordering via some API on the program layout). - // - auto layoutContext = getInitialLayoutContextForTarget(this, nullptr, rules); - - RefPtr<TypeLayout> result; - auto key = TypeLayoutKey{type, rules}; - if (getTypeLayouts().tryGetValue(key, result)) - return result.Ptr(); - result = createTypeLayout(layoutContext, type); - getTypeLayouts()[key] = result; - return result.Ptr(); -} - -// -// TranslationUnitRequest -// - -TranslationUnitRequest::TranslationUnitRequest(FrontEndCompileRequest* compileRequest) - : compileRequest(compileRequest) -{ - module = new Module(compileRequest->getLinkage()); -} - -TranslationUnitRequest::TranslationUnitRequest(FrontEndCompileRequest* compileRequest, Module* m) - : compileRequest(compileRequest), module(m), isChecked(true) -{ - moduleName = getNamePool()->getName(m->getName()); -} - -Session* TranslationUnitRequest::getSession() -{ - return compileRequest->getSession(); -} - -NamePool* TranslationUnitRequest::getNamePool() -{ - return compileRequest->getNamePool(); -} - -SourceManager* TranslationUnitRequest::getSourceManager() -{ - return compileRequest->getSourceManager(); -} - -Scope* TranslationUnitRequest::getLanguageScope() -{ - Scope* languageScope = nullptr; - switch (sourceLanguage) - { - case SourceLanguage::HLSL: - languageScope = getSession()->hlslLanguageScope; - break; - case SourceLanguage::GLSL: - languageScope = getSession()->glslLanguageScope; - break; - case SourceLanguage::Slang: - default: - languageScope = getSession()->slangLanguageScope; - break; - } - return languageScope; -} - -Dictionary<String, String> TranslationUnitRequest::getCombinedPreprocessorDefinitions() -{ - Dictionary<String, String> combinedPreprocessorDefinitions; - for (const auto& def : preprocessorDefinitions) - combinedPreprocessorDefinitions.addIfNotExists(def); - for (const auto& def : compileRequest->optionSet.getArray(CompilerOptionName::MacroDefine)) - combinedPreprocessorDefinitions.addIfNotExists(def.stringValue, def.stringValue2); - - // Define standard macros, if not already defined. This style assumes using `#if __SOME_VAR` - // style, as in - // - // ``` - // #if __SLANG_COMPILER__ - // ``` - // - // This choice is made because slang outputs a warning on using a variable in an #if if not - // defined - // - // Of course this means using #ifndef/#ifdef/defined() is probably not appropraite with thes - // variables. - { - // Used to identify level of HLSL language compatibility - combinedPreprocessorDefinitions.addIfNotExists("__HLSL_VERSION", "2018"); - - // Indicates this is being compiled by the slang *compiler* - combinedPreprocessorDefinitions.addIfNotExists("__SLANG_COMPILER__", "1"); - - // Set macro depending on source type - switch (sourceLanguage) - { - case SourceLanguage::HLSL: - // Used to indicate compiled as HLSL language - combinedPreprocessorDefinitions.addIfNotExists("__HLSL__", "1"); - break; - case SourceLanguage::Slang: - // Used to indicate compiled as Slang language - combinedPreprocessorDefinitions.addIfNotExists("__SLANG__", "1"); - break; - default: - break; - } - - // If not set, define as 0. - combinedPreprocessorDefinitions.addIfNotExists("__HLSL__", "0"); - combinedPreprocessorDefinitions.addIfNotExists("__SLANG__", "0"); - } - - return combinedPreprocessorDefinitions; -} - -void TranslationUnitRequest::addSourceArtifact(IArtifact* sourceArtifact) -{ - SLANG_ASSERT(sourceArtifact); - m_sourceArtifacts.add(ComPtr<IArtifact>(sourceArtifact)); -} - - -void TranslationUnitRequest::addSource(IArtifact* sourceArtifact, SourceFile* sourceFile) -{ - SLANG_ASSERT(sourceArtifact && sourceFile); - // Must be in sync! - SLANG_ASSERT(m_sourceFiles.getCount() == m_sourceArtifacts.getCount()); - - addSourceArtifact(sourceArtifact); - _addSourceFile(sourceFile); -} - -void TranslationUnitRequest::addIncludedSourceFileIfNotExist(SourceFile* sourceFile) -{ - if (m_includedFileSet.contains(sourceFile)) - return; - - sourceFile->setIncludedFile(); - m_sourceFiles.add(sourceFile); - m_includedFileSet.add(sourceFile); -} - -PathInfo TranslationUnitRequest::_findSourcePathInfo(IArtifact* artifact) -{ - auto pathRep = findRepresentation<IPathArtifactRepresentation>(artifact); - - if (pathRep && pathRep->getPathType() == SLANG_PATH_TYPE_FILE) - { - // See if we have a unique identity set with the path - if (const auto uniqueIdentity = pathRep->getUniqueIdentity()) - { - return PathInfo::makeNormal(pathRep->getPath(), uniqueIdentity); - } - - // If we couldn't get a unique identity, just use the path - return PathInfo::makePath(pathRep->getPath()); - } - - // If there isn't a path, we can try with the name - const char* name = artifact->getName(); - if (name && name[0] != 0) - { - return PathInfo::makeFromString(name); - } - - return PathInfo::makeUnknown(); -} - -SlangResult TranslationUnitRequest::requireSourceFiles() -{ - SLANG_ASSERT(m_sourceFiles.getCount() <= m_sourceArtifacts.getCount()); - - if (m_sourceFiles.getCount() == m_sourceArtifacts.getCount()) - { - return SLANG_OK; - } - - auto sink = compileRequest->getSink(); - SourceManager* sourceManager = compileRequest->getSourceManager(); - - for (Index i = m_sourceFiles.getCount(); i < m_sourceArtifacts.getCount(); ++i) - { - IArtifact* artifact = m_sourceArtifacts[i]; - - const PathInfo pathInfo = _findSourcePathInfo(artifact); - - SourceFile* sourceFile = nullptr; - ComPtr<ISlangBlob> blob; - - // If we have a unique identity see if we have it already - if (pathInfo.hasUniqueIdentity()) - { - // See if this an already loaded source file - sourceFile = sourceManager->findSourceFileRecursively(pathInfo.uniqueIdentity); - // If we have a sourceFile see if it has a blob - if (sourceFile) - { - blob = sourceFile->getContentBlob(); - } - } - - // If we *don't* have a blob try and get a blob from the artifact - if (!blob) - { - const SlangResult res = artifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); - if (SLANG_FAILED(res)) - { - // Report couldn't load - sink->diagnose(SourceLoc(), Diagnostics::cannotOpenFile, pathInfo.getName()); - return res; - } - } - - // If we don't have a blob on the artifact we can now add the one we have - if (!findRepresentation<ISlangBlob>(artifact)) - { - artifact->addRepresentationUnknown(blob); - } - - // If we have a sourceFile check if it has contents, and set the blob if doesn't - if (sourceFile) - { - if (!sourceFile->getContentBlob()) - { - sourceFile->setContents(blob); - } - } - else - { - // Create a new source file, using the pathInfo and blob - sourceFile = sourceManager->createSourceFileWithBlob(pathInfo, blob); - } - - auto uniqueIdentity = pathInfo.getMostUniqueIdentity(); - if (uniqueIdentity.getLength()) - sourceManager->addSourceFileIfNotExist(uniqueIdentity, sourceFile); - - // Finally add the source file - _addSourceFile(sourceFile); - } - - return SLANG_OK; -} - -void TranslationUnitRequest::_addSourceFile(SourceFile* sourceFile) -{ - m_sourceFiles.add(sourceFile); - - getModule()->addFileDependency(sourceFile); - getModule()->getIncludedSourceFileMap().add(sourceFile, nullptr); -} - -List<SourceFile*> const& TranslationUnitRequest::getSourceFiles() -{ - return m_sourceFiles; -} - -EndToEndCompileRequest::~EndToEndCompileRequest() -{ - // Flush any writers associated with the request - m_writers->flushWriters(); - - m_linkage.setNull(); - m_frontEndReq.setNull(); -} - -static ISlangWriter* _getDefaultWriter(WriterChannel chan) -{ - static FileWriter stdOut(stdout, WriterFlag::IsStatic | WriterFlag::IsUnowned); - static FileWriter stdError(stderr, WriterFlag::IsStatic | WriterFlag::IsUnowned); - static NullWriter nullWriter(WriterFlag::IsStatic | WriterFlag::IsConsole); - - switch (chan) - { - case WriterChannel::StdError: - return &stdError; - case WriterChannel::StdOutput: - return &stdOut; - case WriterChannel::Diagnostic: - return &nullWriter; - default: - { - SLANG_ASSERT(!"Unknown type"); - return &stdError; - } - } -} - -void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) -{ - // If the user passed in null, we will use the default writer on that channel - m_writers->setWriter(SlangWriterChannel(chan), writer ? writer : _getDefaultWriter(chan)); - - // For diagnostic output, if the user passes in nullptr, we set on m_sink.writer as that enables - // buffering on DiagnosticSink - if (chan == WriterChannel::Diagnostic) - { - m_sink.writer = writer; - } -} - -SlangResult Linkage::loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) -{ - outPathInfo.type = PathInfo::Type::Unknown; - - SLANG_RETURN_ON_FAIL(m_fileSystemExt->loadFile(path.getBuffer(), outBlob)); - - ComPtr<ISlangBlob> uniqueIdentity; - // Get the unique identity - if (SLANG_FAILED( - m_fileSystemExt->getFileUniqueIdentity(path.getBuffer(), uniqueIdentity.writeRef()))) - { - // We didn't get a unique identity, so go with just a found path - outPathInfo.type = PathInfo::Type::FoundPath; - outPathInfo.foundPath = path; - } - else - { - outPathInfo = PathInfo::makeNormal(path, StringUtil::getString(uniqueIdentity)); - } - return SLANG_OK; -} - -Expr* Linkage::parseTermString(String typeStr, Scope* scope) -{ - // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc - // will be cleaned up - SourceManager localSourceManager; - localSourceManager.initialize(getSourceManager(), nullptr); - - Slang::SourceFile* srcFile = - localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr); - - // We'll use a temporary diagnostic sink - DiagnosticSink sink(&localSourceManager, nullptr); - - // RAII type to make make sure current SourceManager is restored after parse. - // Use RAII - to make sure everything is reset even if an exception is thrown. - struct ScopeReplaceSourceManager - { - ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager) - : m_linkage(linkage), m_originalSourceManager(linkage->getSourceManager()) - { - linkage->setSourceManager(replaceManager); - } - - ~ScopeReplaceSourceManager() { m_linkage->setSourceManager(m_originalSourceManager); } - - private: - Linkage* m_linkage; - SourceManager* m_originalSourceManager; - }; - - // We need to temporarily replace the SourceManager for this CompileRequest - ScopeReplaceSourceManager scopeReplaceSourceManager(this, &localSourceManager); - - SourceLanguage sourceLanguage = SourceLanguage::Slang; - SlangLanguageVersion languageVersion = m_optionSet.getLanguageVersion(); - - auto tokens = preprocessSource( - srcFile, - &sink, - nullptr, - Dictionary<String, String>(), - this, - sourceLanguage, - languageVersion); - - if (sourceLanguage == SourceLanguage::Unknown) - sourceLanguage = SourceLanguage::Slang; - - return parseTermFromSourceFile( - getASTBuilder(), - tokens, - &sink, - scope, - getNamePool(), - sourceLanguage); -} - -Type* checkProperType(Linkage* linkage, TypeExp typeExp, DiagnosticSink* sink); - -Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink) -{ - // If we've looked up this type name before, - // then we can re-use it. - // - Type* type = nullptr; - if (m_types.tryGetValue(typeStr, type)) - return type; - - - // TODO(JS): For now just used the linkages ASTBuilder to keep on scope - // - // The parseTermString uses the linkage ASTBuilder for it's parsing. - // - // It might be possible to just create a temporary ASTBuilder - the worry though is - // that the parsing sets a member variable in AST node to one of these scopes, and then - // it become a dangling pointer. So for now we go with the linkages. - auto astBuilder = getLinkage()->getASTBuilder(); - - // Otherwise, we need to start looking in - // the modules that were directly or - // indirectly referenced. - // - Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); - - auto linkage = getLinkage(); - - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - - Expr* typeExpr = linkage->parseTermString(typeStr, scope); - SharedSemanticsContext sharedSemanticsContext(linkage, nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); - type = visitor.TranslateTypeNode(typeExpr); - auto typeOut = visitor.tryCoerceToProperType(TypeExp(type)); - if (typeOut.type) - type = typeOut.type; - - if (type) - { - m_types[typeStr] = type; - } - return type; -} - -Expr* ComponentType::findDeclFromString(String const& name, DiagnosticSink* sink) -{ - // If we've looked up this type name before, - // then we can re-use it. - // - Expr* result = nullptr; - if (m_decls.tryGetValue(name, result)) - return result; - - - // TODO(JS): For now just used the linkages ASTBuilder to keep on scope - // - // The parseTermString uses the linkage ASTBuilder for it's parsing. - // - // It might be possible to just create a temporary ASTBuilder - the worry though is - // that the parsing sets a member variable in AST node to one of these scopes, and then - // it become a dangling pointer. So for now we go with the linkages. - auto astBuilder = getLinkage()->getASTBuilder(); - - // Otherwise, we need to start looking in - // the modules that were directly or - // indirectly referenced. - // - Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); - - auto linkage = getLinkage(); - - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - - Expr* expr = linkage->parseTermString(name, scope); - - SemanticsContext context(linkage->getSemanticsForReflection()); - context = context.allowStaticReferenceToNonStaticMember().withSink(sink); - - SemanticsVisitor visitor(context); - - auto checkedExpr = visitor.CheckTerm(expr); - - if (as<DeclRefExpr>(checkedExpr) || as<OverloadedExpr>(checkedExpr)) - { - result = checkedExpr; - } - - m_decls[name] = result; - return result; -} - -bool isSimpleName(String const& name) -{ - for (char c : name) - { - if (!CharUtil::isAlphaOrDigit(c) && c != '_' && c != '$') - return false; - } - return true; -} - -Expr* ComponentType::findDeclFromStringInType( - Type* type, - String const& name, - LookupMask mask, - DiagnosticSink* sink) -{ - // Only look up in the type if it is a DeclRefType - if (!as<DeclRefType>(type)) - return nullptr; - - // TODO(JS): For now just used the linkages ASTBuilder to keep on scope - // - // The parseTermString uses the linkage ASTBuilder for it's parsing. - // - // It might be possible to just create a temporary ASTBuilder - the worry though is - // that the parsing sets a member variable in AST node to one of these scopes, and then - // it become a dangling pointer. So for now we go with the linkages. - auto astBuilder = getLinkage()->getASTBuilder(); - - // Otherwise, we need to start looking in - // the modules that were directly or - // indirectly referenced. - // - Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); - - auto linkage = getLinkage(); - - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - - Expr* expr = nullptr; - - if (isSimpleName(name)) - { - auto varExpr = astBuilder->create<VarExpr>(); - varExpr->scope = scope; - varExpr->name = getLinkage()->getNamePool()->getName(name); - expr = varExpr; - } - else - { - expr = linkage->parseTermString(name, scope); - } - SemanticsContext context(linkage->getSemanticsForReflection()); - context = context.allowStaticReferenceToNonStaticMember().withSink(sink); - - SemanticsVisitor visitor(context); - - GenericAppExpr* genericOuterExpr = nullptr; - if (as<GenericAppExpr>(expr)) - { - // Unwrap the generic application, and re-wrap it around the static-member expr - genericOuterExpr = as<GenericAppExpr>(expr); - expr = genericOuterExpr->functionExpr; - } - - if (!as<VarExpr>(expr)) - return nullptr; - - auto rs = astBuilder->create<StaticMemberExpr>(); - auto typeExpr = astBuilder->create<SharedTypeExpr>(); - auto typetype = astBuilder->getOrCreate<TypeType>(type); - typeExpr->type = typetype; - rs->baseExpression = typeExpr; - rs->name = as<VarExpr>(expr)->name; - - expr = rs; - - // If we have a generic-app expression, re-wrap the static-member expr - if (genericOuterExpr) - { - genericOuterExpr->functionExpr = expr; - expr = genericOuterExpr; - } - - auto checkedTerm = visitor.CheckTerm(expr); - - // Check if checkedTerm is overloaded functions and avoid resolving if so - // to preserve all function overloads with different signatures - Expr* resolvedTerm = checkedTerm; - if (auto overloadedExpr = as<OverloadedExpr>(checkedTerm)) - { - // Check if all candidates are function references - bool allAreFunctions = true; - for (auto item : overloadedExpr->lookupResult2.items) - { - if (!as<FunctionDeclBase>(item.declRef.getDecl())) - { - allAreFunctions = false; - break; - } - } - - // If not all are functions, resolve the overload as usual - if (!allAreFunctions) - { - resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); - } - } - else - { - // Not overloaded, resolve as usual - resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); - } - - - if (auto overloadedExpr = as<OverloadedExpr>(resolvedTerm)) - { - return overloadedExpr; - } - if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm)) - { - return declRefExpr; - } - - return nullptr; -} - -bool ComponentType::isSubType(Type* subType, Type* superType) -{ - SemanticsContext context(getLinkage()->getSemanticsForReflection()); - SemanticsVisitor visitor(context); - - return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr); -} - -static void collectExportedConstantInContainer( - Dictionary<String, IntVal*>& dict, - ASTBuilder* builder, - ContainerDecl* containerDecl) -{ - for (auto varMember : containerDecl->getDirectMemberDeclsOfType<VarDeclBase>()) - { - if (!varMember->val) - continue; - bool isExported = false; - bool isConst = false; - bool isExtern = false; - for (auto modifier : varMember->modifiers) - { - if (as<HLSLExportModifier>(modifier)) - isExported = true; - if (as<ExternAttribute>(modifier) || as<ExternModifier>(modifier)) - { - isExtern = true; - isExported = true; - } - if (as<ConstModifier>(modifier)) - isConst = true; - } - if (isExported && isConst) - { - auto mangledName = getMangledName(builder, varMember); - if (isExtern && dict.containsKey(mangledName)) - continue; - dict[mangledName] = varMember->val; - } - } - - for (auto member : containerDecl->getDirectMemberDecls()) - { - if (as<NamespaceDecl>(member) || as<FileDecl>(member)) - { - collectExportedConstantInContainer(dict, builder, (ContainerDecl*)member); - } - } -} - -Dictionary<String, IntVal*>& ComponentType::getMangledNameToIntValMap() -{ - if (m_mapMangledNameToIntVal) - { - return *m_mapMangledNameToIntVal; - } - m_mapMangledNameToIntVal = std::make_unique<Dictionary<String, IntVal*>>(); - auto astBuilder = getLinkage()->getASTBuilder(); - SLANG_AST_BUILDER_RAII(astBuilder); - Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); - for (; scope; scope = scope->nextSibling) - { - if (scope->containerDecl) - collectExportedConstantInContainer( - *m_mapMangledNameToIntVal, - astBuilder, - scope->containerDecl); - } - return *m_mapMangledNameToIntVal; -} - -ConstantIntVal* ComponentType::tryFoldIntVal(IntVal* intVal) -{ - auto astBuilder = getLinkage()->getASTBuilder(); - SLANG_AST_BUILDER_RAII(astBuilder); - return as<ConstantIntVal>(intVal->linkTimeResolve(getMangledNameToIntValMap())); -} - -CompileRequestBase::CompileRequestBase(Linkage* linkage, DiagnosticSink* sink) - : m_linkage(linkage), m_sink(sink) -{ -} - - -FrontEndCompileRequest::FrontEndCompileRequest( - Linkage* linkage, - StdWriters* writers, - DiagnosticSink* sink) - : CompileRequestBase(linkage, sink), m_writers(writers) -{ - optionSet.inheritFrom(linkage->m_optionSet); -} - -/// Handlers for preprocessor callbacks to use when doing ordinary front-end compilation -struct FrontEndPreprocessorHandler : PreprocessorHandler -{ -public: - FrontEndPreprocessorHandler( - Module* module, - ASTBuilder* astBuilder, - DiagnosticSink* sink, - TranslationUnitRequest* translationUnit) - : m_module(module) - , m_astBuilder(astBuilder) - , m_sink(sink) - , m_translationUnit(translationUnit) - { - } - -protected: - Module* m_module; - ASTBuilder* m_astBuilder; - DiagnosticSink* m_sink; - TranslationUnitRequest* m_translationUnit = nullptr; - - // The first task that this handler tries to deal with is - // capturing all the files on which a module is dependent. - // - // That information is exposed through public APIs and used - // by applications to decide when they need to "hot reload" - // their shader code. - // - void handleFileDependency(SourceFile* sourceFile) SLANG_OVERRIDE - { - m_module->addFileDependency(sourceFile); - m_translationUnit->addIncludedSourceFileIfNotExist(sourceFile); - } - - // The second task that this handler deals with is detecting - // whether any macro values were set in a given source file - // that are semantically relevant to other stages of compilation. - // - void handleEndOfTranslationUnit(Preprocessor* preprocessor) SLANG_OVERRIDE - { - // We look at the preprocessor state after reading the entire - // source file/string, in order to see if any macros have been - // set that should be considered semantically relevant for - // later stages of compilation. - // - // Note: Checking the macro environment *after* preprocessing is complete - // means that we can treat macros introduced via `-D` options or the API - // equivalently to macros introduced via `#define`s in user code. - // - // For now, the only case of semantically-relevant macros we need to worrry - // about are the NVAPI macros used to establish the register/space to use. - // - static const char* kNVAPIRegisterMacroName = "NV_SHADER_EXTN_SLOT"; - static const char* kNVAPISpaceMacroName = "NV_SHADER_EXTN_REGISTER_SPACE"; - - // For NVAPI use, the `NV_SHADER_EXTN_SLOT` macro is required to be defined. - // - String nvapiRegister; - SourceLoc nvapiRegisterLoc; - if (!SLANG_FAILED(findMacroValue( - preprocessor, - kNVAPIRegisterMacroName, - nvapiRegister, - nvapiRegisterLoc))) - { - // In contrast, NVAPI can be used without defining `NV_SHADER_EXTN_REGISTER_SPACE`, - // which effectively defaults to `space0`. - // - String nvapiSpace = "space0"; - SourceLoc nvapiSpaceLoc; - findMacroValue(preprocessor, kNVAPISpaceMacroName, nvapiSpace, nvapiSpaceLoc); - - // We are going to store the values of these macros on the AST-level `ModuleDecl` - // so that they will be available to later processing stages. - // - auto moduleDecl = m_module->getModuleDecl(); - - if (auto existingModifier = moduleDecl->findModifier<NVAPISlotModifier>()) - { - // If there is already a modifier attached to the module (perhaps - // because of preprocessing a different source file, or because - // of settings established via command-line options), then we - // need to validate that the values being set in this file - // match those already set (or else there is likely to be - // some kind of error in the user's code). - // - _validateNVAPIMacroMatch( - kNVAPIRegisterMacroName, - existingModifier->registerName, - nvapiRegister, - nvapiRegisterLoc); - _validateNVAPIMacroMatch( - kNVAPISpaceMacroName, - existingModifier->spaceName, - nvapiSpace, - nvapiSpaceLoc); - } - else - { - // If there is no existing modifier on the module, then we - // take responsibility for adding one, based on the macro - // values we saw. - // - auto modifier = m_astBuilder->create<NVAPISlotModifier>(); - modifier->loc = nvapiRegisterLoc; - modifier->registerName = nvapiRegister; - modifier->spaceName = nvapiSpace; - - addModifier(moduleDecl, modifier); - } - } - } - - /// Validate that a re-defintion of an NVAPI-related macro matches any previous definition - void _validateNVAPIMacroMatch( - char const* macroName, - String const& existingValue, - String const& newValue, - SourceLoc loc) - { - if (existingValue != newValue) - { - m_sink->diagnose( - loc, - Diagnostics::nvapiMacroMismatch, - macroName, - existingValue, - newValue); - } - } -}; - - -// Holds the hierarchy of views, the children being views that were 'initiated' (have an initiating -// SourceLoc) in the parent. -typedef Dictionary<SourceView*, List<SourceView*>> ViewInitiatingHierarchy; - -// Calculate the hierarchy from the sourceManager -static void _calcViewInitiatingHierarchy( - SourceManager* sourceManager, - ViewInitiatingHierarchy& outHierarchy) -{ - const List<SourceView*> emptyList; - outHierarchy.clear(); - - // Iterate over all managers - for (SourceManager* curManager = sourceManager; curManager; - curManager = curManager->getParent()) - { - // Iterate over all views - for (SourceView* view : curManager->getSourceViews()) - { - if (view->getInitiatingSourceLoc().isValid()) - { - // Look up the view it came from - SourceView* parentView = - sourceManager->findSourceViewRecursively(view->getInitiatingSourceLoc()); - if (parentView) - { - List<SourceView*>& children = outHierarchy.getOrAddValue(parentView, emptyList); - // It shouldn't have already been added - SLANG_ASSERT(children.indexOf(view) < 0); - children.add(view); - } - } - } - } - - // Order all the children, by their raw SourceLocs. This is desirable, so that a trivial - // traversal will traverse children in the order they are initiated in the parent source. This - // assumes they increase in SourceLoc implies an later within a source file - this is true - // currently. - for (auto& [_, value] : outHierarchy) - { - value.sort( - [](SourceView* a, SourceView* b) -> bool { - return a->getInitiatingSourceLoc().getRaw() < b->getInitiatingSourceLoc().getRaw(); - }); - } -} - -// Given a source file, find the view that is the initial SourceView use of the source. It must have -// an initiating SourceLoc that is not valid. -static SourceView* _findInitialSourceView(SourceFile* sourceFile) -{ - // TODO(JS): - // This might be overkill - presumably the SourceView would belong to the same manager as it's - // SourceFile? That is not enforced by the SourceManager in any way though so we just search all - // managers, and all views. - for (SourceManager* sourceManager = sourceFile->getSourceManager(); sourceManager; - sourceManager = sourceManager->getParent()) - { - for (SourceView* view : sourceManager->getSourceViews()) - { - if (view->getSourceFile() == sourceFile && !view->getInitiatingSourceLoc().isValid()) - { - return view; - } - } - } - - return nullptr; -} - -static void _outputInclude(SourceFile* sourceFile, Index depth, DiagnosticSink* sink) -{ - StringBuilder buf; - - for (Index i = 0; i < depth; ++i) - { - buf << " "; - } - - // Output the found path for now - // TODO(JS). We could use the verbose paths flag to control what path is output -> as it may be - // useful to output the full path for example - - const PathInfo& pathInfo = sourceFile->getPathInfo(); - buf << "'" << pathInfo.foundPath << "'"; - - // TODO(JS)? - // You might want to know where this include was from. - // If I output this though there will be a problem... as the indenting won't be clearly shown. - // Perhaps I output in two sections, one the hierarchy and the other the locations of the - // includes? - - sink->diagnose(SourceLoc(), Diagnostics::includeOutput, buf); -} - -static void _outputIncludesRec( - SourceView* sourceView, - Index depth, - ViewInitiatingHierarchy& hierarchy, - DiagnosticSink* sink) -{ - SourceFile* sourceFile = sourceView->getSourceFile(); - const PathInfo& pathInfo = sourceFile->getPathInfo(); - - switch (pathInfo.type) - { - case PathInfo::Type::TokenPaste: - case PathInfo::Type::CommandLine: - case PathInfo::Type::TypeParse: - { - // If any of these types we don't output - return; - } - default: - break; - } - - // Okay output this file at the current depth - _outputInclude(sourceFile, depth, sink); - - // Now recurse to all of the children at the next depth - List<SourceView*>* children = hierarchy.tryGetValue(sourceView); - if (children) - { - for (SourceView* child : *children) - { - _outputIncludesRec(child, depth + 1, hierarchy, sink); - } - } -} - -static void _outputPreprocessorTokens(const TokenList& toks, ISlangWriter* writer) -{ - if (writer == nullptr) - { - return; - } - - StringBuilder buf; - for (const auto& tok : toks) - { - buf << tok.getContent(); - // We'll separate tokens with space for now - buf.appendChar(' '); - } - - buf.appendChar('\n'); - - writer->write(buf.getBuffer(), buf.getLength()); -} - -static void _outputIncludes( - const List<SourceFile*>& sourceFiles, - SourceManager* sourceManager, - DiagnosticSink* sink) -{ - // Set up the hierarchy to know how all the source views relate. This could be argued as - // overkill, but makes recursive output pretty simple - ViewInitiatingHierarchy hierarchy; - _calcViewInitiatingHierarchy(sourceManager, hierarchy); - - // For all the source files - for (SourceFile* sourceFile : sourceFiles) - { - if (sourceFile->isIncludedFile()) - continue; - - // Find an initial view (this is the view of this file, that doesn't have an initiating loc) - SourceView* sourceView = _findInitialSourceView(sourceFile); - if (!sourceView) - { - // Okay, didn't find one, so just output the file - _outputInclude(sourceFile, 0, sink); - } - else - { - // Output from this view recursively - _outputIncludesRec(sourceView, 0, hierarchy, sink); - } - } -} - -void FrontEndCompileRequest::parseTranslationUnit(TranslationUnitRequest* translationUnit) -{ - SLANG_PROFILE; - if (translationUnit->isChecked) - return; - - auto linkage = getLinkage(); - - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - - // TODO(JS): NOTE! Here we are using the searchDirectories on the linkage. This is because - // currently the API only allows the setting search paths on linkage. - // - // Here we should probably be using the searchDirectories on the FrontEndCompileRequest. - // If searchDirectories.parent pointed to the one in the Linkage would mean linkage paths - // would be checked too (after those on the FrontEndCompileRequest). - IncludeSystem includeSystem( - &linkage->getSearchDirectories(), - linkage->getFileSystemExt(), - linkage->getSourceManager()); - - auto combinedPreprocessorDefinitions = translationUnit->getCombinedPreprocessorDefinitions(); - - auto module = translationUnit->getModule(); - - ASTBuilder* astBuilder = module->getASTBuilder(); - - ModuleDecl* translationUnitSyntax = astBuilder->create<ModuleDecl>(); - - translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; - translationUnitSyntax->module = module; - module->setModuleDecl(translationUnitSyntax); - - // When compiling a module of code that belongs to the Slang - // core module, we add a modifier to the module to act - // as a marker, so that downstream code can detect declarations - // that came from the core module (by walking up their - // chain of ancestors and looking for the marker), and treat - // them differently from user declarations. - // - // We are adding the marker here, before we even parse the - // code in the module, in case the subsequent steps would - // like to treat the core module differently. Alternatively - // we could pass down the `m_isStandardLibraryCode` flag to - // these passes. - // - if (m_isCoreModuleCode) - { - translationUnitSyntax->modifiers.first = astBuilder->create<FromCoreModuleModifier>(); - } - - // We use a custom handler for preprocessor callbacks, to - // ensure that relevant state that is only visible during - // preprocessoing can be communicated to later phases of - // compilation. - // - FrontEndPreprocessorHandler preprocessorHandler(module, astBuilder, getSink(), translationUnit); - - for (auto sourceFile : translationUnit->getSourceFiles()) - { - module->getIncludedSourceFileMap().addIfNotExists(sourceFile, nullptr); - } - - for (auto sourceFile : translationUnit->getSourceFiles()) - { - SourceLanguage sourceLanguage = translationUnit->sourceLanguage; - SlangLanguageVersion languageVersion = - translationUnit->compileRequest->optionSet.getLanguageVersion(); - auto tokens = preprocessSource( - sourceFile, - getSink(), - &includeSystem, - combinedPreprocessorDefinitions, - getLinkage(), - sourceLanguage, - languageVersion, - &preprocessorHandler); - - translationUnitSyntax->languageVersion = languageVersion; - - if (sourceLanguage == SourceLanguage::Unknown) - sourceLanguage = translationUnit->sourceLanguage; - - Scope* languageScope = nullptr; - switch (sourceLanguage) - { - case SourceLanguage::HLSL: - languageScope = getSession()->hlslLanguageScope; - break; - case SourceLanguage::GLSL: - languageScope = getSession()->glslLanguageScope; - break; - case SourceLanguage::Slang: - default: - languageScope = getSession()->slangLanguageScope; - break; - } - - if (optionSet.getBoolOption(CompilerOptionName::OutputIncludes)) - { - _outputIncludes( - translationUnit->getSourceFiles(), - getSink()->getSourceManager(), - getSink()); - } - - if (optionSet.getBoolOption(CompilerOptionName::PreprocessorOutput)) - { - if (m_writers) - { - _outputPreprocessorTokens( - tokens, - m_writers->getWriter(SLANG_WRITER_CHANNEL_STD_OUTPUT)); - } - // If we output the preprocessor output then we are done doing anything else - return; - } - - parseSourceFile( - astBuilder, - translationUnit, - sourceLanguage, - tokens, - getSink(), - languageScope, - translationUnitSyntax); - - // Let's try dumping - - if (optionSet.getBoolOption(CompilerOptionName::DumpAst)) - { - StringBuilder buf; - SourceWriter writer(linkage->getSourceManager(), LineDirectiveMode::None, nullptr); - - ASTDumpUtil::dump( - translationUnit->getModuleDecl(), - ASTDumpUtil::Style::Flat, - 0, - &writer); - - const String& path = sourceFile->getPathInfo().foundPath; - if (path.getLength()) - { - String fileName = Path::getFileNameWithoutExt(path); - fileName.append(".slang-ast"); - - File::writeAllText(fileName, writer.getContent()); - } - } - -#if 0 - // Test serialization - { - ASTSerialTestUtil::testSerialize(translationUnit->getModuleDecl(), getSession()->getNamePool(), getLinkage()->getASTBuilder()->getSharedASTBuilder(), getSourceManager()); - } -#endif - } -} - -RefPtr<ComponentType> createUnspecializedGlobalComponentType( - FrontEndCompileRequest* compileRequest); - -RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( - FrontEndCompileRequest* compileRequest, - List<RefPtr<ComponentType>>& outUnspecializedEntryPoints); - -RefPtr<ComponentType> createSpecializedGlobalComponentType(EndToEndCompileRequest* endToEndReq); - -RefPtr<ComponentType> createSpecializedGlobalAndEntryPointsComponentType( - EndToEndCompileRequest* endToEndReq, - List<RefPtr<ComponentType>>& outSpecializedEntryPoints); - -void FrontEndCompileRequest::checkAllTranslationUnits() -{ - SLANG_PROFILE; - - LoadedModuleDictionary loadedModules; - if (additionalLoadedModules) - loadedModules = *additionalLoadedModules; - - // Iterate over all translation units and - // apply the semantic checking logic. - for (auto& translationUnit : translationUnits) - { - if (translationUnit->isChecked) - continue; - - checkTranslationUnit(translationUnit.Ptr(), loadedModules); - - // Add the checked module to list of loadedModules so that they can be - // discovered by `findOrImportModule` when processing future `import` decls. - // TODO: this does not handle the case where a translation unit to discover - // another translation unit added later to the compilation request. - // We should output an error message when we detect such a case, or support - // this scenario with a recursive style checking. - loadedModules.add(translationUnit->moduleName, translationUnit->getModule()); - } - checkEntryPoints(); -} - -void FrontEndCompileRequest::generateIR() -{ - SLANG_PROFILE; - SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); - - // Our task in this function is to generate IR code - // for all of the declarations in the translation - // units that were loaded. - - // Each translation unit is its own little world - // for code generation (we are not trying to - // replicate the GLSL linkage model), and so - // we will generate IR for each (if needed) - // in isolation. - for (auto& translationUnit : translationUnits) - { - // Skip if the module is precompiled. - if (translationUnit->getModule()->getIRModule()) - continue; - - // We want to only run generateIRForTranslationUnit once here. This is for two side effects: - // * it can dump ir - // * it can generate diagnostics - - /// Generate IR for translation unit. - RefPtr<IRModule> irModule( - generateIRForTranslationUnit(getLinkage()->getASTBuilder(), translationUnit)); - - if (verifyDebugSerialization) - { - SerialContainerUtil::WriteOptions options; - - options.sourceManagerToUseWhenSerializingSourceLocs = getSourceManager(); - - // Verify debug information - if (SLANG_FAILED( - SerialContainerUtil::verifyIRSerialize(irModule, getSession(), options))) - { - getSink()->diagnose( - irModule->getModuleInst()->sourceLoc, - Diagnostics::serialDebugVerificationFailed); - } - } - - // Set the module on the translation unit - translationUnit->getModule()->setIRModule(irModule); - } -} - -// Try to infer a single common source language for a request -static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) -{ - SourceLanguage language = SourceLanguage::Unknown; - for (auto& translationUnit : request->translationUnits) - { - // Allow any other language to overide Slang as a choice - if (language == SourceLanguage::Unknown || language == SourceLanguage::Slang) - { - language = translationUnit->sourceLanguage; - } - else if (language == translationUnit->sourceLanguage) - { - // same language as we currently have, so keep going - } - else - { - // we found a mismatch, so inference fails - return SourceLanguage::Unknown; - } - } - return language; -} - -SlangResult FrontEndCompileRequest::executeActionsInner() -{ - SLANG_PROFILE_SECTION(frontEndExecute); - SLANG_AST_BUILDER_RAII(getLinkage()->getASTBuilder()); - - for (TranslationUnitRequest* translationUnit : translationUnits) - { - // Make sure SourceFile representation is available for all translationUnits - SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); - } - - - // Parse everything from the input files requested - for (TranslationUnitRequest* translationUnit : translationUnits) - { - parseTranslationUnit(translationUnit); - } - - if (optionSet.getBoolOption(CompilerOptionName::PreprocessorOutput)) - { - // If doing pre-processor output, then we are done - return SLANG_OK; - } - - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - // Perform semantic checking on the whole collection - { - SLANG_PROFILE_SECTION(SemanticChecking); - checkAllTranslationUnits(); - } - - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - // After semantic checking is performed we can try and output doc information for this - if (optionSet.getBoolOption(CompilerOptionName::Doc)) - { - // TODO: implement the logic to output generated documents to target directory/zip file. - } - - // Look up all the entry points that are expected, - // and use them to populate the `program` member. - // - m_globalComponentType = createUnspecializedGlobalComponentType(this); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - m_globalAndEntryPointsComponentType = - createUnspecializedGlobalAndEntryPointsComponentType(this, m_unspecializedEntryPoints); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - // We always generate IR for all the translation units. - // - // TODO: We may eventually have a mode where we skip - // IR codegen and only produce an AST (e.g., for use when - // debugging problems in the parser or semantic checking), - // but for now there are no cases where not having IR - // makes sense. - // - generateIR(); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - // Do parameter binding generation, for each compilation target. - // - for (auto targetReq : getLinkage()->targets) - { - auto targetProgram = m_globalAndEntryPointsComponentType->getTargetProgram(targetReq); - targetProgram->getOrCreateLayout(getSink()); - targetProgram->getOrCreateIRModuleForLayout(getSink()); - } - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - return SLANG_OK; -} - -EndToEndCompileRequest::EndToEndCompileRequest(Session* session) - : m_session(session), m_sink(nullptr, Lexer::sourceLocationLexer) -{ - RefPtr<ASTBuilder> astBuilder( - new ASTBuilder(session->m_sharedASTBuilder, "EndToEnd::Linkage::astBuilder")); - m_linkage = new Linkage(session, astBuilder, session->getBuiltinLinkage()); - init(); -} - -EndToEndCompileRequest::EndToEndCompileRequest(Linkage* linkage) - : m_session(linkage->getSessionImpl()) - , m_linkage(linkage) - , m_sink(nullptr, Lexer::sourceLocationLexer) -{ - init(); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -EndToEndCompileRequest::queryInterface(SlangUUID const& uuid, void** outObject) -{ - if (uuid == EndToEndCompileRequest::getTypeGuid()) - { - // Special case to cast directly into internal type - // NOTE! No addref(!) - *outObject = this; - return SLANG_OK; - } - - if (uuid == ISlangUnknown::getTypeGuid() && uuid == ICompileRequest::getTypeGuid()) - { - addReference(); - *outObject = static_cast<slang::ICompileRequest*>(this); - return SLANG_OK; - } - - return SLANG_E_NO_INTERFACE; -} - -void EndToEndCompileRequest::init() -{ - m_sink.setSourceManager(m_linkage->getSourceManager()); - - m_writers = new StdWriters; - - // Set all the default writers - for (int i = 0; i < int(WriterChannel::CountOf); ++i) - { - setWriter(WriterChannel(i), nullptr); - } - - m_frontEndReq = new FrontEndCompileRequest(getLinkage(), m_writers, getSink()); -} - -SlangResult EndToEndCompileRequest::executeActionsInner() -{ - SLANG_PROFILE_SECTION(endToEndActions); - // If no code-generation target was specified, then try to infer one from the source language, - // just to make sure we can do something reasonable when invoked from the command line. - // - // TODO: This logic should be moved into `options.cpp` or somewhere else - // specific to the command-line tool. - // - if (getLinkage()->targets.getCount() == 0) - { - auto language = inferSourceLanguage(getFrontEndReq()); - switch (language) - { - case SourceLanguage::HLSL: - getLinkage()->addTarget(CodeGenTarget::DXBytecode); - break; - - case SourceLanguage::GLSL: - getLinkage()->addTarget(CodeGenTarget::SPIRV); - break; - - default: - break; - } - } - - // Update compiler settings in target requests. - for (auto target : getLinkage()->targets) - target->getOptionSet().inheritFrom(getOptionSet()); - m_frontEndReq->optionSet = getOptionSet(); - - // We only do parsing and semantic checking if we *aren't* doing - // a pass-through compilation. - // - if (m_passThrough == PassThroughMode::None) - { - SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner()); - } - - if (getOptionSet().getBoolOption(CompilerOptionName::PreprocessorOutput)) - { - return SLANG_OK; - } - - // If command line specifies to skip codegen, we exit here. - // Note: this is a debugging option. - // - if (getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) - { - // We will use the program (and matching layout information) - // that was computed in the front-end for all subsequent - // reflection queries, etc. - // - m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); - m_specializedGlobalAndEntryPointsComponentType = - getUnspecializedGlobalAndEntryPointsComponentType(); - m_specializedEntryPoints = getFrontEndReq()->getUnspecializedEntryPoints(); - - SLANG_RETURN_ON_FAIL(maybeCreateContainer()); - - SLANG_RETURN_ON_FAIL(maybeWriteContainer(m_containerOutputPath)); - - return SLANG_OK; - } - - // If requested, attempt to compile the translation unit all the way down to the target - // language(s) and stash the result blobs in IR. - for (auto target : getLinkage()->targets) - { - SlangCompileTarget targetEnum = SlangCompileTarget(target->getTarget()); - if (target->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) - { - auto frontEndReq = getFrontEndReq(); - - for (auto translationUnit : frontEndReq->translationUnits) - { - SLANG_RETURN_ON_FAIL( - translationUnit->getModule()->precompileForTarget(targetEnum, nullptr)); - - if (frontEndReq->optionSet.shouldDumpIR()) - { - DiagnosticSinkWriter writer(frontEndReq->getSink()); - - dumpIR( - translationUnit->getModule()->getIRModule(), - frontEndReq->m_irDumpOptions, - "PRECOMPILE_FOR_TARGET_COMPLETE_ALL", - frontEndReq->getSourceManager(), - &writer); - - dumpIR( - translationUnit->getModule()->getIRModule()->getModuleInst(), - frontEndReq->m_irDumpOptions, - frontEndReq->getSourceManager(), - &writer); - } - } - } - } - - // If codegen is enabled, we need to move along to - // apply any generic specialization that the user asked for. - // - if (m_passThrough == PassThroughMode::None) - { - m_specializedGlobalComponentType = createSpecializedGlobalComponentType(this); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - m_specializedGlobalAndEntryPointsComponentType = - createSpecializedGlobalAndEntryPointsComponentType(this, m_specializedEntryPoints); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - // For each code generation target, we will generate specialized - // parameter binding information (taking global generic - // arguments into account at this time). - // - for (auto targetReq : getLinkage()->targets) - { - auto targetProgram = - m_specializedGlobalAndEntryPointsComponentType->getTargetProgram(targetReq); - targetProgram->getOrCreateLayout(getSink()); - } - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - } - else - { - // We need to create dummy `EntryPoint` objects - // to make sure that the logic in `generateOutput` - // sees something worth processing. - // - List<RefPtr<ComponentType>> dummyEntryPoints; - for (auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) - { - RefPtr<EntryPoint> dummyEntryPoint = EntryPoint::createDummyForPassThrough( - getLinkage(), - entryPointReq->getName(), - entryPointReq->getProfile()); - - dummyEntryPoints.add(dummyEntryPoint); - } - - RefPtr<ComponentType> composedProgram = - CompositeComponentType::create(getLinkage(), dummyEntryPoints); - - m_specializedGlobalComponentType = getUnspecializedGlobalComponentType(); - m_specializedGlobalAndEntryPointsComponentType = composedProgram; - m_specializedEntryPoints = getFrontEndReq()->getUnspecializedEntryPoints(); - } - - // Generate output code, in whatever format was requested - generateOutput(); - if (getSink()->getErrorCount() != 0) - return SLANG_FAIL; - - return SLANG_OK; -} - -// Act as expected of the API-based compiler -SlangResult EndToEndCompileRequest::executeActions() -{ - SlangResult res = executeActionsInner(); - - m_diagnosticOutput = getSink()->outputBuffer.produceString(); - return res; -} - -int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) -{ - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this); - translationUnit->compileRequest = this; - translationUnit->sourceLanguage = SourceLanguage(language); - - translationUnit->setModuleName(moduleName); - return addTranslationUnit(translationUnit); -} - -int FrontEndCompileRequest::addTranslationUnit(TranslationUnitRequest* translationUnit) -{ - Index result = translationUnits.getCount(); - translationUnits.add(translationUnit); - return (int)result; -} - -void FrontEndCompileRequest::addTranslationUnitSourceArtifact( - int translationUnitIndex, - IArtifact* sourceArtifact) -{ - auto translationUnit = translationUnits[translationUnitIndex]; - - // Add the source file - translationUnit->addSourceArtifact(sourceArtifact); - - if (!translationUnit->moduleName) - { - translationUnit->setModuleName( - getNamePool()->getName(Path::getFileNameWithoutExt(sourceArtifact->getName()))); - } - if (translationUnit->module->getFilePath() == nullptr) - translationUnit->module->setPathInfo(PathInfo::makePath(sourceArtifact->getName())); -} - -void FrontEndCompileRequest::addTranslationUnitSourceBlob( - int translationUnitIndex, - String const& path, - ISlangBlob* sourceBlob) -{ - auto translationUnit = translationUnits[translationUnitIndex]; - auto sourceDesc = - ArtifactDescUtil::makeDescForSourceLanguage(asExternal(translationUnit->sourceLanguage)); - - auto artifact = ArtifactUtil::createArtifact(sourceDesc, path.getBuffer()); - artifact->addRepresentationUnknown(sourceBlob); - - addTranslationUnitSourceArtifact(translationUnitIndex, artifact); -} - -void FrontEndCompileRequest::addTranslationUnitSourceFile( - int translationUnitIndex, - String const& path) -{ - // TODO: We need to consider whether a relative `path` should cause - // us to look things up using the registered search paths. - // - // This behavior wouldn't make sense for command-line invocations - // of `slangc`, but at least one API user wondered by the search - // paths were not taken into account by this function. - // - - auto fileSystemExt = getLinkage()->getFileSystemExt(); - auto translationUnit = getTranslationUnit(translationUnitIndex); - - auto sourceDesc = - ArtifactDescUtil::makeDescForSourceLanguage(asExternal(translationUnit->sourceLanguage)); - - auto sourceArtifact = ArtifactUtil::createArtifact(sourceDesc, path.getBuffer()); - - auto extRep = new ExtFileArtifactRepresentation(path.getUnownedSlice(), fileSystemExt); - sourceArtifact->addRepresentation(extRep); - - SlangResult existsRes = SLANG_OK; - - // If we require caching, we demand it's loaded here. - // - // In practice this probably means repro capture is enabled. So we want to - // load the blob such that it's in the cache, even if it doesn't actually - // have to be loaded for the compilation. - if (getLinkage()->m_requireCacheFileSystem) - { - ComPtr<ISlangBlob> blob; - // If we can load the blob, then it exists - existsRes = sourceArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); - } - else - { - existsRes = sourceArtifact->exists() ? SLANG_OK : SLANG_E_NOT_FOUND; - } - - if (SLANG_FAILED(existsRes)) - { - // Emit a diagnostic! - getSink()->diagnose(SourceLoc(), Diagnostics::cannotOpenFile, path); - return; - } - - addTranslationUnitSourceArtifact(translationUnitIndex, sourceArtifact); -} - -int FrontEndCompileRequest::addEntryPoint( - int translationUnitIndex, - String const& name, - Profile entryPointProfile) -{ - auto translationUnitReq = translationUnits[translationUnitIndex]; - - Index result = m_entryPointReqs.getCount(); - - RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest( - this, - translationUnitIndex, - getNamePool()->getName(name), - entryPointProfile); - - m_entryPointReqs.add(entryPointReq); - // translationUnitReq->entryPoints.add(entryPointReq); - - return int(result); -} - -int EndToEndCompileRequest::addEntryPoint( - int translationUnitIndex, - String const& name, - Profile entryPointProfile, - List<String> const& genericTypeNames) -{ - getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile); - - EntryPointInfo entryPointInfo; - for (auto typeName : genericTypeNames) - entryPointInfo.specializationArgStrings.add(typeName); - - Index result = m_entryPoints.getCount(); - m_entryPoints.add(_Move(entryPointInfo)); - return (int)result; -} - -UInt Linkage::addTarget(CodeGenTarget target) -{ - RefPtr<TargetRequest> targetReq = new TargetRequest(this, target); - - Index result = targets.getCount(); - targets.add(targetReq); - return UInt(result); -} - -void Linkage::loadParsedModule( - RefPtr<FrontEndCompileRequest> compileRequest, - RefPtr<TranslationUnitRequest> translationUnit, - Name* name, - const PathInfo& pathInfo) -{ - // Note: we add the loaded module to our name->module listing - // before doing semantic checking, so that if it tries to - // recursively `import` itself, we can detect it. - // - RefPtr<Module> loadedModule = translationUnit->getModule(); - - // Get a path - String mostUniqueIdentity = pathInfo.getMostUniqueIdentity(); - SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); - - mapPathToLoadedModule.add(mostUniqueIdentity, loadedModule); - mapNameToLoadedModules.add(name, loadedModule); - - auto sink = translationUnit->compileRequest->getSink(); - - int errorCountBefore = sink->getErrorCount(); - int errorCountAfter; - try - { - compileRequest->checkAllTranslationUnits(); - } - catch (...) - { - mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(name); - throw; - } - errorCountAfter = sink->getErrorCount(); - if (isInLanguageServer()) - { - // Don't generate IR as language server. - // This means that we currently cannot report errors that are detected during IR passes. - // Ideally we want to run those passes, but that is too risky for what it is worth right - // now. - } - else - { - if (errorCountAfter != errorCountBefore) - { - // There must have been an error in the loaded module. - // Remove from maps if there were errors during semantic checking - mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(name); - } - else - { - // If we didn't run into any errors, then try to generate - // IR code for the imported module. - if (errorCountAfter == 0) - { - loadedModule->setIRModule( - generateIRForTranslationUnit(getASTBuilder(), translationUnit)); - } - } - } - loadedModulesList.add(loadedModule); -} - -RefPtr<Module> Linkage::findOrLoadSerializedModuleForModuleLibrary( - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* libraryChunk, - DiagnosticSink* sink) -{ - RefPtr<Module> resultModule; - - // We will attempt things in a few different steps, trying to - // decode as little of the serialized module as necessary at - // each step, so that we don't waste time on the heavyweight - // stuff when we didn't need to. - // - // The first step is to simply decode the module name, and - // see if we have a already loaded a matching module. - - auto moduleName = getNamePool()->getName(moduleChunk->getName()); - if (mapNameToLoadedModules.tryGetValue(moduleName, resultModule)) - return resultModule; - - // It is possible that the module has been loaded, but somehow - // under a different name, so next we decode the list of file - // paths that the module depends on, and then rely on the assumption - // that the first of those paths represents the file for the module - // itself to detect if we've already loaded a module from that - // path. - // - // Note: While this is a distasteful assumption to make, it is - // one that gets made in several parts of the compiler codebase - // already. It isn't something that can be fixed in just one - // place at this point. - - auto fileDependenciesList = moduleChunk->getFileDependencies(); - auto firstFileDependencyChunk = fileDependenciesList.getFirst(); - if (!firstFileDependencyChunk) - return nullptr; - - auto modulePathInfo = PathInfo::makePath(firstFileDependencyChunk->getValue()); - if (mapPathToLoadedModule.tryGetValue(modulePathInfo.getMostUniqueIdentity(), resultModule)) - return resultModule; - - // If we failed to find a previously-loaded module, then we - // will go ahead and load the module from the serialized form. - // - PathInfo filePathInfo; - return loadSerializedModule( - moduleName, - modulePathInfo, - blobHoldingSerializedData, - moduleChunk, - libraryChunk, - SourceLoc(), - sink); -} - -RefPtr<Module> Linkage::loadSerializedModule( - Name* moduleName, - const PathInfo& moduleFilePathInfo, - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* containerChunk, - SourceLoc const& requestingLoc, - DiagnosticSink* sink) -{ - auto astBuilder = getASTBuilder(); - SLANG_AST_BUILDER_RAII(astBuilder); - - auto module = RefPtr(new Module(this, astBuilder)); - module->setName(moduleName); - - // Just as if we were processing an `import` declaration in - // source code, we will track the fact that this serialized - // modlue is (effectively) being imported, so that we can - // diagnose anything troublesome, like an attempt at a - // recursive import. - // - ModuleBeingImportedRAII moduleBeingImported(this, module, moduleName, requestingLoc); - - // We will register the module in our data structures to - // track loaded modules, and then remove it in the case - // where there is some kind of failure. - // - String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); - SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); - - mapPathToLoadedModule.add(mostUniqueIdentity, module); - mapNameToLoadedModules.add(moduleName, module); - try - { - if (SLANG_FAILED(loadSerializedModuleContents( - module, - moduleFilePathInfo, - blobHoldingSerializedData, - moduleChunk, - containerChunk, - sink))) - { - mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(moduleName); - return nullptr; - } - - loadedModulesList.add(module); - return module; - } - catch (...) - { - mapPathToLoadedModule.remove(mostUniqueIdentity); - mapNameToLoadedModules.remove(moduleName); - throw; - } -} - -RefPtr<Module> Linkage::loadBinaryModuleImpl( - Name* moduleName, - const PathInfo& moduleFilePathInfo, - ISlangBlob* moduleFileContents, - SourceLoc const& requestingLoc, - DiagnosticSink* sink) -{ - auto astBuilder = getASTBuilder(); - SLANG_AST_BUILDER_RAII(astBuilder); - - // We start by reading the content of the file as - // an in-memory RIFF container. - // - auto rootChunk = RIFF::RootChunk::getFromBlob(moduleFileContents); - if (!rootChunk) - { - return nullptr; - } - - auto moduleChunk = ModuleChunk::find(rootChunk); - if (!moduleChunk) - { - return nullptr; - } - - // Next, we attempt to check if the binary module is up to - // date with the compilation options in use as well as - // the contents of all the files its compilation depended - // on (as determined by its hash). - // - String mostUniqueIdentity = moduleFilePathInfo.getMostUniqueIdentity(); - SLANG_ASSERT(mostUniqueIdentity.getLength() > 0); - if (m_optionSet.getBoolOption(CompilerOptionName::UseUpToDateBinaryModule)) - { - if (!isBinaryModuleUpToDate(moduleFilePathInfo.foundPath, moduleChunk)) - { - return nullptr; - } - } - - // If everything seems reasonable, then we will go ahead and load - // the module more completely from that serialized representation. - // - RefPtr<Module> module = loadSerializedModule( - moduleName, - moduleFilePathInfo, - moduleFileContents, - moduleChunk, - rootChunk, - requestingLoc, - sink); - - return module; -} - -void Linkage::_diagnoseErrorInImportedModule(DiagnosticSink* sink) -{ - for (auto info = m_modulesBeingImported; info; info = info->next) - { - sink->diagnose(info->importLoc, Diagnostics::errorInImportedModule, info->name); - } - if (!isInLanguageServer()) - { - sink->diagnose(SourceLoc(), Diagnostics::complationCeased); - } -} - -RefPtr<Module> Linkage::loadModuleImpl( - Name* moduleName, - const PathInfo& modulePathInfo, - ISlangBlob* moduleBlob, - SourceLoc const& requestingLoc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules, - ModuleBlobType blobType) -{ - switch (blobType) - { - case ModuleBlobType::IR: - return loadBinaryModuleImpl(moduleName, modulePathInfo, moduleBlob, requestingLoc, sink); - - case ModuleBlobType::Source: - return loadSourceModuleImpl( - moduleName, - modulePathInfo, - moduleBlob, - requestingLoc, - sink, - additionalLoadedModules); - - default: - SLANG_UNEXPECTED("unknown module blob type"); - UNREACHABLE_RETURN(nullptr); - } -} - -RefPtr<Module> Linkage::loadSourceModuleImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* sourceBlob, - SourceLoc const& srcLoc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules) -{ - RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, nullptr, sink); - - frontEndReq->additionalLoadedModules = additionalLoadedModules; - - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq); - translationUnit->compileRequest = frontEndReq; - translationUnit->setModuleName(name); - Stage impliedStage; - translationUnit->sourceLanguage = SourceLanguage::Slang; - - // If we are loading from a file with apparaent glsl extension, - // set the source language to GLSL to enable GLSL compatibility mode. - if ((SourceLanguage)findSourceLanguageFromPath(filePathInfo.getName(), impliedStage) == - SourceLanguage::GLSL) - { - translationUnit->sourceLanguage = SourceLanguage::GLSL; - } - - frontEndReq->addTranslationUnit(translationUnit); - - auto module = translationUnit->getModule(); - - ModuleBeingImportedRAII moduleBeingImported(this, module, name, srcLoc); - - // Create an artifact for the source - auto sourceArtifact = ArtifactUtil::createArtifact( - ArtifactDesc::make(ArtifactKind::Source, ArtifactPayload::Slang, ArtifactStyle::Unknown)); - - if (sourceBlob) - { - // If the user has already provided a source blob, use that. - sourceArtifact->addRepresentation( - new SourceBlobWithPathInfoArtifactRepresentation(filePathInfo, sourceBlob)); - } - else if ( - filePathInfo.type == PathInfo::Type::Normal || - filePathInfo.type == PathInfo::Type::FoundPath) - { - // Create with the 'friendly' name - // We create that it was loaded from the file system - sourceArtifact->addRepresentation(new ExtFileArtifactRepresentation( - filePathInfo.foundPath.getUnownedSlice(), - getFileSystemExt())); - } - else - { - return nullptr; - } - - translationUnit->addSourceArtifact(sourceArtifact); - - if (SLANG_FAILED(translationUnit->requireSourceFiles())) - { - // Some problem accessing source files - return nullptr; - } - int errorCountBefore = sink->getErrorCount(); - frontEndReq->parseTranslationUnit(translationUnit); - int errorCountAfter = sink->getErrorCount(); - - if (errorCountAfter != errorCountBefore && !isInLanguageServer()) - { - _diagnoseErrorInImportedModule(sink); - // Something went wrong during the parsing, so we should bail out. - return nullptr; - } - - try - { - loadParsedModule(frontEndReq, translationUnit, name, filePathInfo); - } - catch (const Slang::AbortCompilationException&) - { - // Something is fatally wrong, we should return nullptr. - module = nullptr; - } - errorCountAfter = sink->getErrorCount(); - - if (errorCountAfter != errorCountBefore && !isInLanguageServer()) - { - // If something is fatally wrong, we want to report - // the diagnostic even if we are in language server - // and processing a different module. - _diagnoseErrorInImportedModule(sink); - // Something went wrong during the parsing, so we should bail out. - return nullptr; - } - - if (!module) - return nullptr; - - module->setPathInfo(filePathInfo); - return module; -} - -bool Linkage::isBeingImported(Module* module) -{ - for (auto ii = m_modulesBeingImported; ii; ii = ii->next) - { - if (module == ii->module) - return true; - } - return false; -} - -// Derive a file name for the module, by taking the given -// identifier, replacing all occurrences of `_` with `-`, -// and then appending `.slang`. -// -// For example, `foo_bar` becomes `foo-bar.slang`. -String getFileNameFromModuleName(Name* name, bool translateUnderScore) -{ - String fileName; - if (!getText(name).getUnownedSlice().endsWithCaseInsensitive(".slang")) - { - StringBuilder sb; - for (auto c : getText(name)) - { - if (translateUnderScore && c == '_') - c = '-'; - - sb.append(c); - } - sb.append(".slang"); - fileName = sb.produceString(); - } - else - { - fileName = getText(name); - } - return fileName; -} - -RefPtr<Module> Linkage::findOrImportModule( - Name* moduleName, - SourceLoc const& requestingLoc, - DiagnosticSink* sink, - const LoadedModuleDictionary* loadedModules) -{ - // Have we already loaded a module matching this name? - // - RefPtr<LoadedModule> previouslyLoadedModule; - if (mapNameToLoadedModules.tryGetValue(moduleName, previouslyLoadedModule)) - { - // If the map shows a null module having been loaded, - // then that means there was a prior load attempt, - // but it failed, so we won't bother trying again. - // - if (!previouslyLoadedModule) - return nullptr; - - // If state shows us that the module is already being - // imported deeper on the call stack, then we've - // hit a recursive case, and that is an error. - // - if (isBeingImported(previouslyLoadedModule)) - { - // We seem to be in the middle of loading this module - sink->diagnose(requestingLoc, Diagnostics::recursiveModuleImport, moduleName); - return nullptr; - } - - return previouslyLoadedModule; - } - - // If the user is providing an additional list of loaded modules, we find - // if the module being imported is in that list. This allows a translation - // unit to use previously checked translation units in the same - // FrontEndCompileRequest. - { - Module* previouslyLoadedLocalModule = nullptr; - if (loadedModules && loadedModules->tryGetValue(moduleName, previouslyLoadedLocalModule)) - { - return previouslyLoadedLocalModule; - } - } - - // If the name being requested matches the name of a built-in module, - // then we will special-case the process by loading that builtin - // module directly. - // - // TODO: right now this logic is only considering the built-in `glsl` - // module, but it should probably be generalized so that we can more - // easily support having multiple built-in modules rather than just - // putting everything into `core`. - // - if (moduleName == getSessionImpl()->glslModuleName) - { - // This is a builtin glsl module, just load it from embedded definition. - auto glslModule = getSessionImpl()->getBuiltinModule(slang::BuiltinModuleName::GLSL); - if (!glslModule) - { - // Note: the way this logic is currently written, if the built-in - // `glsl` module fails to load, then we will *not* fall back to - // searching for a user-defined module in a file like `glsl.slang`. - // - // It is unclear if this should be the default behavior or not. - // Should built-in modules be prioritized over user modules? - // Should built-in modules shadow user modules, even when the - // built-in module fails to load, for some reason? - // - sink->diagnose(requestingLoc, Diagnostics::glslModuleNotAvailable, moduleName); - } - return glslModule; - } - - // We are going to use a loop to search for a suitable file to - // load the module from, to account for a few key choices: - // - // * We can both load modules from a source `.slang` file, - // or from a binary `.slang-module` file. - // - // * For a variety of reasons, the `import` logic has historically - // translated underscores in a module name into dashes (so that - // `import my_module` will look for `my-module.slang`), and we - // try to support both that convention as well as a convention - // that preserves underscores. - // - // To try to keep this logic as orthogonal as possible, we first - // construct lists of the options we want to iterate over, and - // then do the actual loop later. - - ShortList<ModuleBlobType, 2> typesToTry; - if (isInLanguageServer()) - { - // When in language server, we always prefer to use source module if it is available. - typesToTry.add(ModuleBlobType::Source); - typesToTry.add(ModuleBlobType::IR); - } - else - { - // Look for a precompiled module first, if not exist, load from source. - typesToTry.add(ModuleBlobType::IR); - typesToTry.add(ModuleBlobType::Source); - } - - // We will always search for a file name that directly matches the - // module name as written first, and then search for one with - // underscores replaced by dashes. The latter is the original - // behavior that `import` provided, but it seems safest to prefer - // the exact name spelled in the user's code when there might - // actually be ambiguity. - // - auto defaultSourceFileName = getFileNameFromModuleName(moduleName, false); - auto alternativeSourceFileName = getFileNameFromModuleName(moduleName, true); - String sourceFileNamesToTry[] = {defaultSourceFileName, alternativeSourceFileName}; - - // We are going to look for the candidate file using the same - // logic that would be used for a preprocessor `#include`, - // so we set up the necessary state. - // - IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); - - // Just like with a `#include`, the search will take into - // account the path to the file where the request to import - // this module came from (e.g. the source file with the - // `import` declaration), if such a path is available. - // - PathInfo requestingPathInfo = - getSourceManager()->getPathInfo(requestingLoc, SourceLocType::Actual); - - for (auto type : typesToTry) - { - for (auto sourceFileName : sourceFileNamesToTry) - { - // The `sourceFileName` will have the `.slang` extension, - // so if we are looking for a binary module, we need - // to change the extension we will look for. - // - String fileName; - switch (type) - { - case ModuleBlobType::Source: - fileName = sourceFileName; - break; - - case ModuleBlobType::IR: - fileName = Path::replaceExt(sourceFileName, "slang-module"); - break; - } - - // We now search for a file matching the desired name, - // using the same logic as for a `#include`. - // - // TODO: We might want to consider how to handle the case - // of an `import` with a relative path a little specially, - // since it could in theory be possible for two `.slang` - // files with the same base name to exist in different - // directories in a project, and we'd want file-relative - // `import`s to work for each, without having either one - // be able to "claim" the bare identifier of the base - // name for itself. - // - PathInfo filePathInfo; - if (SLANG_FAILED( - includeSystem.findFile(fileName, requestingPathInfo.foundPath, filePathInfo))) - { - // If we failed to find the file at this step, we - // will continue the search for our other options. - // - continue; - } - - // We will *again* search for a previously loaded module. - // - // It is possible that the same file will have been loaded - // as a module under two different module names. The easiest - // way for this to happen is if there are `import` declarations - // using both the underscore and dash conventions (e.g., both - // `import "my-module.slang"` and `import my_module`). - // - // This case may also arise if one file `import`s a module using - // just an identifier for its name, but another `import`s it - // using a path (e.g., `import "subdir/file.slang"`). - // - // No matter how the situation arises, we only want to have one - // copy of the "same" module loaded at a given time, so we - // will re-use the existing module if we find one here. - // - if (mapPathToLoadedModule.tryGetValue( - filePathInfo.getMostUniqueIdentity(), - previouslyLoadedModule)) - { - // TODO: If we find a previously-loaded module at this step, - // then we should probably register that module under the - // given `moduleName` in the map of loaded modules, so - // that subsequent `import`s using the same form will find it. - // - return previouslyLoadedModule; - } - - // Now we try to load the content of the file. - // - // If for some reason we could find a file at the - // given path, but for some reason couldn't *open* - // and *read* it, then we continue the search - // using whatever other candidate file names are left. - // - ComPtr<ISlangBlob> fileContents; - if (SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) - { - continue; - } - - // If we found a real file and were able to load its contents, - // then we'll go ahead and try to load a module from it, - // whether by compiling it or decoding the binary. - // - auto module = loadModuleImpl( - moduleName, - filePathInfo, - fileContents, - requestingLoc, - sink, - loadedModules, - type); - - // If the attempt to load the module from the given path - // was successful, we go ahead and use it, without trying - // out any other options. - // - if (module) - return module; - } - } - - // If we tried out all of our candidate file names - // and failed with each of them, then we diagnose - // an error based on the original *source* file - // name. - // - // TODO: this should really be an error message - // that clearly states something like "no file - // suitable for module `whatever` was found - // and loaded. - // - // Ideally that error message would include whatever - // of the candidate file names from the loop above - // got furthest along in the process (or just a - // list of the file names that were tried, if - // nothing was even found via the include system). - // - sink->diagnose(requestingLoc, Diagnostics::cannotOpenFile, defaultSourceFileName); - - // If the attempt to import the module failed, then - // we will stick a null pointer into the map of loaded - // modules, so that subsequent attempts to load a module - // with this name will return null without having to - // go through all the above steps yet again. - // - mapNameToLoadedModules[moduleName] = nullptr; - return nullptr; -} - -SourceFile* Linkage::loadSourceFile(String pathFrom, String path) -{ - IncludeSystem includeSystem(&getSearchDirectories(), getFileSystemExt(), getSourceManager()); - ComPtr<slang::IBlob> blob; - PathInfo pathInfo; - SLANG_RETURN_NULL_ON_FAIL(includeSystem.findFile(path, pathFrom, pathInfo)); - SourceFile* sourceFile = nullptr; - SLANG_RETURN_NULL_ON_FAIL(includeSystem.loadFile(pathInfo, blob, sourceFile)); - return sourceFile; -} - -// Check if a serialized module is up-to-date with current compiler options and source files. -bool Linkage::isBinaryModuleUpToDate(String fromPath, RIFF::ListChunk const* baseChunk) -{ - auto moduleChunk = ModuleChunk::find(baseChunk); - if (!moduleChunk) - return false; - - SHA1::Digest existingDigest = moduleChunk->getDigest(); - - DigestBuilder<SHA1> digestBuilder; - auto version = String(getBuildTagString()); - digestBuilder.append(version); - m_optionSet.buildHash(digestBuilder); - - // Find the canonical path of the directory containing the module source file. - String moduleSrcPath = ""; - - auto dependencyChunks = moduleChunk->getFileDependencies(); - if (auto firstDependencyChunk = dependencyChunks.getFirst()) - { - moduleSrcPath = firstDependencyChunk->getValue(); - - IncludeSystem includeSystem( - &getSearchDirectories(), - getFileSystemExt(), - getSourceManager()); - PathInfo modulePathInfo; - if (SLANG_SUCCEEDED(includeSystem.findFile(moduleSrcPath, fromPath, modulePathInfo))) - { - moduleSrcPath = modulePathInfo.foundPath; - Path::getCanonical(moduleSrcPath, moduleSrcPath); - } - } - - for (auto dependencyChunk : dependencyChunks) - { - auto file = dependencyChunk->getValue(); - auto sourceFile = loadSourceFile(fromPath, file); - if (!sourceFile) - { - // If we cannot find the source file from `fromPath`, - // try again from the module's source file path. - if (dependencyChunks.getFirst()) - sourceFile = loadSourceFile(moduleSrcPath, file); - } - if (!sourceFile) - return false; - digestBuilder.append(sourceFile->getDigest()); - } - return digestBuilder.finalize() == existingDigest; -} - -SLANG_NO_THROW bool SLANG_MCALL -Linkage::isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) -{ - auto rootChunk = RIFF::RootChunk::getFromBlob(binaryModuleBlob); - if (!rootChunk) - return false; - return isBinaryModuleUpToDate(modulePath, rootChunk); -} - -SourceFile* Linkage::findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem) -{ - auto impl = [&](bool translateUnderScore) -> SourceFile* - { - auto fileName = getFileNameFromModuleName(name, translateUnderScore); - - // Next, try to find the file of the given name, - // using our ordinary include-handling logic. - - auto& searchDirs = getSearchDirectories(); - outIncludeSystem = IncludeSystem(&searchDirs, getFileSystemExt(), getSourceManager()); - - // Get the original path info - PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); - PathInfo filePathInfo; - - ComPtr<ISlangBlob> fileContents; - - // We have to load via the found path - as that is how file was originally loaded - if (SLANG_FAILED( - outIncludeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) - { - return nullptr; - } - // Otherwise, try to load it. - SourceFile* sourceFile; - if (SLANG_FAILED(outIncludeSystem.loadFile(filePathInfo, fileContents, sourceFile))) - { - return nullptr; - } - return sourceFile; - }; - if (auto rs = impl(false)) - return rs; - return impl(true); -} - -Linkage::IncludeResult Linkage::findAndIncludeFile( - Module* module, - TranslationUnitRequest* translationUnit, - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink) -{ - IncludeResult result; - result.fileDecl = nullptr; - result.isNew = false; - - IncludeSystem includeSystem; - auto sourceFile = findFile(name, loc, includeSystem); - if (!sourceFile) - { - sink->diagnose(loc, Diagnostics::cannotOpenFile, getText(name)); - return result; - } - - // If the file has already been included, don't need to do anything further. - if (auto existingFileDecl = module->getIncludedSourceFileMap().tryGetValue(sourceFile)) - { - result.fileDecl = *existingFileDecl; - result.isNew = false; - return result; - } - - if (isInLanguageServer()) - { - // HACK: When in language server mode, we will always load the currently opend file as a - // fresh module even if some previously opened file already references the current file via - // `import` or `include`. see comments in `WorkspaceVersion::getOrLoadModule()` for the - // reason behind this. An undesired outcome of this decision is that we could endup - // including the currently opened file itself via chain of `__include`s because the - // currently opened file will not have a true unique file system identity that allows it to - // be deduplicated correct. Therefore we insert a hack logic here to detect re-inclusion by - // just the file path. We can clean up this hack by making the language server truly support - // incremental checking so we can reuse the previously loaded module instead of needing to - // always start with a fresh copy. - // - for (auto file : translationUnit->getSourceFiles()) - { - if (file->getPathInfo().hasFoundPath() && - Path::equals(file->getPathInfo().foundPath, sourceFile->getPathInfo().foundPath)) - return result; - } - } - - module->addFileDependency(sourceFile); - - // Create a transparent FileDecl to hold all children from the included file. - auto fileDecl = module->getASTBuilder()->create<FileDecl>(); - fileDecl->nameAndLoc.name = name; - fileDecl->parentDecl = module->getModuleDecl(); - module->getIncludedSourceFileMap().add(sourceFile, fileDecl); - - FrontEndPreprocessorHandler preprocessorHandler( - module, - module->getASTBuilder(), - sink, - translationUnit); - auto combinedPreprocessorDefinitions = translationUnit->getCombinedPreprocessorDefinitions(); - SourceLanguage sourceLanguage = translationUnit->sourceLanguage; - SlangLanguageVersion slangLanguageVersion = module->getModuleDecl()->languageVersion; - auto tokens = preprocessSource( - sourceFile, - sink, - &includeSystem, - combinedPreprocessorDefinitions, - this, - sourceLanguage, - slangLanguageVersion, - &preprocessorHandler); - - if (sourceLanguage == SourceLanguage::Unknown) - sourceLanguage = translationUnit->sourceLanguage; - - if (slangLanguageVersion != module->getModuleDecl()->languageVersion) - { - sink->diagnose( - tokens.begin()->getLoc(), - Diagnostics::languageVersionDiffersFromIncludingModule); - } - - auto outerScope = module->getModuleDecl()->ownedScope; - parseSourceFile( - module->getASTBuilder(), - translationUnit, - sourceLanguage, - tokens, - sink, - outerScope, - fileDecl); - - module->getModuleDecl()->addMember(fileDecl); - - result.fileDecl = fileDecl; - result.isNew = true; - return result; -} - -// -// ModuleDependencyList -// - -void ModuleDependencyList::addDependency(Module* module) -{ - // If we depend on a module, then we depend on everything it depends on. - // - // Note: We are processing these sub-depenencies before adding the - // `module` itself, so that in the common case a module will always - // appear *after* everything it depends on. - // - // However, this rule is being violated in the compiler right now because - // the modules for hte top-level translation units of a compile request - // will be added to the list first (using `addLeafDependency`) to - // maintain compatibility with old behavior. This may be fixed later. - // - for (auto subDependency : module->getModuleDependencyList()) - { - _addDependency(subDependency); - } - _addDependency(module); -} - -void ModuleDependencyList::addLeafDependency(Module* module) -{ - _addDependency(module); -} - -void ModuleDependencyList::_addDependency(Module* module) -{ - if (m_moduleSet.contains(module)) - return; - - m_moduleList.add(module); - m_moduleSet.add(module); -} - -// -// FileDependencyList -// - -void FileDependencyList::addDependency(SourceFile* sourceFile) -{ - if (m_fileSet.contains(sourceFile)) - return; - - m_fileList.add(sourceFile); - m_fileSet.add(sourceFile); -} - -void FileDependencyList::addDependency(Module* module) -{ - for (SourceFile* sourceFile : module->getFileDependencyList()) - { - addDependency(sourceFile); - } -} - -// -// Module -// - -Module::Module(Linkage* linkage, ASTBuilder* astBuilder) - : ComponentType(linkage), m_mangledExportPool(StringSlicePool::Style::Empty) -{ - if (astBuilder) - { - m_astBuilder = astBuilder; - } - else - { - m_astBuilder = linkage->getASTBuilder(); - } - getOptionSet() = linkage->m_optionSet; - addModuleDependency(this); -} - -ISlangUnknown* Module::getInterface(const Guid& guid) -{ - if (guid == IModule::getTypeGuid()) - return asExternal(this); - if (guid == IModulePrecompileService_Experimental::getTypeGuid()) - return static_cast<slang::IModulePrecompileService_Experimental*>(this); - return Super::getInterface(guid); -} - -void Module::buildHash(DigestBuilder<SHA1>& builder) -{ - builder.append(computeDigest()); -} - -slang::DeclReflection* Module::getModuleReflection() -{ - return (slang::DeclReflection*)m_moduleDecl; -} - -SHA1::Digest Module::computeDigest() -{ - if (m_digest == SHA1::Digest()) - { - DigestBuilder<SHA1> digestBuilder; - auto version = String(getBuildTagString()); - digestBuilder.append(version); - getOptionSet().buildHash(digestBuilder); - - auto fileDependencies = getFileDependencies(); - - for (auto file : fileDependencies) - { - digestBuilder.append(file->getDigest()); - } - m_digest = digestBuilder.finalize(); - } - return m_digest; -} - -void Module::addModuleDependency(Module* module) -{ - m_moduleDependencyList.addDependency(module); - m_fileDependencyList.addDependency(module); -} - -void Module::addFileDependency(SourceFile* sourceFile) -{ - m_fileDependencyList.addDependency(sourceFile); -} - -void Module::setModuleDecl(ModuleDecl* moduleDecl) -{ - m_moduleDecl = moduleDecl; - moduleDecl->module = this; -} - -void Module::setName(String name) -{ - m_name = getLinkage()->getNamePool()->getName(name); -} - - -RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) -{ - for (auto entryPoint : m_entryPoints) - { - if (entryPoint->getName()->text.getUnownedSlice() == name) - return entryPoint; - } - - return nullptr; -} - -RefPtr<EntryPoint> Module::findAndCheckEntryPoint( - UnownedStringSlice const& name, - SlangStage stage, - ISlangBlob** outDiagnostics) -{ - // If there is already an entrypoint marked with the [shader] attribute, - // we should just return that. - // - if (auto existingEntryPoint = findEntryPointByName(name)) - return existingEntryPoint; - - SLANG_AST_BUILDER_RAII(m_astBuilder); - - // If the function hasn't been marked as [shader], then it won't be discovered - // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` - // function. To do that we need to setup a FrontEndCompileRequest and a - // FrontEndEntryPointRequest. - // - DiagnosticSink sink(getLinkage()->getSourceManager(), DiagnosticSink::SourceLocationLexer()); - FrontEndCompileRequest frontEndRequest(getLinkage(), StdWriters::getSingleton(), &sink); - RefPtr<TranslationUnitRequest> tuRequest = new TranslationUnitRequest(&frontEndRequest); - tuRequest->module = this; - tuRequest->moduleName = m_name; - frontEndRequest.translationUnits.add(tuRequest); - FrontEndEntryPointRequest entryPointRequest( - &frontEndRequest, - 0, - getLinkage()->getNamePool()->getName(name), - Profile((Stage)stage)); - auto result = findAndValidateEntryPoint(&entryPointRequest); - if (outDiagnostics) - { - sink.getBlobIfNeeded(outDiagnostics); - } - return result; -} - -void Module::_addEntryPoint(EntryPoint* entryPoint) -{ - m_entryPoints.add(entryPoint); -} - -static bool _canExportDeclSymbol(ASTNodeType type) -{ - switch (type) - { - case ASTNodeType::EmptyDecl: - { - return false; - } - default: - break; - } - - return true; -} - -static bool _canRecurseExportSymbol(Decl* decl) -{ - if (as<FunctionDeclBase>(decl) || as<ScopeDecl>(decl)) - { - return false; - } - return true; -} - -void Module::_processFindDeclsExportSymbolsRec(Decl* decl) -{ - if (_canExportDeclSymbol(decl->astNodeType)) - { - // It's a reference to a declaration in another module, so first get the symbol name. - String mangledName = getMangledName(getCurrentASTBuilder(), decl); - - Index index = Index(m_mangledExportPool.add(mangledName)); - - // TODO(JS): It appears that more than one entity might have the same mangled name. - // So for now we ignore and just take the first one. - if (index == m_mangledExportSymbols.getCount()) - { - m_mangledExportSymbols.add(decl); - } - } - - if (!_canRecurseExportSymbol(decl)) - { - // We don't need to recurse any further into this - return; - } - - // If it's a container process it's children - if (auto containerDecl = as<ContainerDecl>(decl)) - { - for (auto child : containerDecl->getDirectMemberDecls()) - { - _processFindDeclsExportSymbolsRec(child); - } - } - - // GenericDecl is also a container, so do subsequent test - if (auto genericDecl = as<GenericDecl>(decl)) - { - _processFindDeclsExportSymbolsRec(genericDecl->inner); - } -} - -Decl* Module::findExportedDeclByMangledName(const UnownedStringSlice& mangledName) -{ - // If this module is a serialized module that is being - // deserialized on-demand, then we want to use the - // mangled name mapping that was baked into the serialized - // data, rather than attempt to enumerate all of the declarations - // in the module (as would be done if we proceeded to call - // `ensureExportLookupAcceleratorBuilt()`). - // - if (this->m_moduleDecl->isUsingOnDemandDeserializationForExports()) - { - return m_moduleDecl->_findSerializedDeclByMangledExportName(mangledName); - } - - ensureExportLookupAcceleratorBuilt(); - - const Index index = m_mangledExportPool.findIndex(mangledName); - return (index >= 0) ? m_mangledExportSymbols[index] : nullptr; -} - -void Module::ensureExportLookupAcceleratorBuilt() -{ - // Will be non zero if has been previously attempted - if (m_mangledExportSymbols.getCount() == 0) - { - // Build up the exported mangled name list - _processFindDeclsExportSymbolsRec(getModuleDecl()); - - // If nothing found, mark that we have tried looking by making - // m_mangledExportSymbols.getCount() != 0 - if (m_mangledExportSymbols.getCount() == 0) - { - m_mangledExportSymbols.add(nullptr); - } - } -} - -Count Module::getExportedDeclCount() -{ - ensureExportLookupAcceleratorBuilt(); - - return m_mangledExportPool.getSlicesCount(); -} - -Decl* Module::getExportedDecl(Index index) -{ - ensureExportLookupAcceleratorBuilt(); - return m_mangledExportSymbols[index]; -} - -UnownedStringSlice Module::getExportedDeclMangledName(Index index) -{ - ensureExportLookupAcceleratorBuilt(); - return m_mangledExportPool.getSlices()[index]; -} - -// ComponentType - -ComponentType::ComponentType(Linkage* linkage) - : m_linkage(linkage) -{ -} - -ComponentType* asInternal(slang::IComponentType* inComponentType) -{ - // Note: we use a `queryInterface` here instead of just a `static_cast` - // to ensure that the `IComponentType` we get is the preferred/canonical - // one, which shares its address with the `ComponentType`. - // - // TODO: An alternative choice here would be to have a "magic" IID that - // we pass into `queryInterface` that returns the `ComponentType` directly - // (without even `addRef`-ing it). - // - ComPtr<slang::IComponentType> componentType; - inComponentType->queryInterface(SLANG_IID_PPV_ARGS(componentType.writeRef())); - return static_cast<ComponentType*>(componentType.get()); -} - -ISlangUnknown* ComponentType::getInterface(Guid const& guid) -{ - if (guid == ISlangUnknown::getTypeGuid() || guid == slang::IComponentType::getTypeGuid()) - { - return static_cast<slang::IComponentType*>(this); - } - if (guid == IModulePrecompileService_Experimental::getTypeGuid()) - return static_cast<slang::IModulePrecompileService_Experimental*>(this); - if (guid == IComponentType2::getTypeGuid()) - return static_cast<slang::IComponentType2*>(this); - return nullptr; -} - -SLANG_NO_THROW slang::ISession* SLANG_MCALL ComponentType::getSession() -{ - return m_linkage; -} - -SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL -ComponentType::getLayout(Int targetIndex, slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return nullptr; - auto target = linkage->targets[targetIndex]; - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - auto programLayout = getTargetProgram(target)->getOrCreateLayout(&sink); - sink.getBlobIfNeeded(outDiagnostics); - - return asExternal(programLayout); -} - -static ICastable* _findDiagnosticRepresentation(IArtifact* artifact) -{ - if (auto rep = findAssociatedRepresentation<IArtifactDiagnostics>(artifact)) - { - return rep; - } - - for (auto associated : artifact->getAssociated()) - { - if (isDerivedFrom(associated->getDesc().payload, ArtifactPayload::Diagnostics)) - { - return associated; - } - } - return nullptr; -} - -static IArtifact* _findObfuscatedSourceMap(IArtifact* artifact) -{ - // If we find any obfuscated source maps, we are done - for (auto associated : artifact->getAssociated()) - { - const auto desc = associated->getDesc(); - - if (isDerivedFrom(desc.payload, ArtifactPayload::SourceMap) && - isDerivedFrom(desc.style, ArtifactStyle::Obfuscated)) - { - return associated; - } - } - return nullptr; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getResultAsFileSystem( - SlangInt entryPointIndex, - Int targetIndex, - ISlangMutableFileSystem** outFileSystem) -{ - ComPtr<ISlangBlob> diagnostics; - ComPtr<ISlangBlob> code; - - SLANG_RETURN_ON_FAIL( - getEntryPointCode(entryPointIndex, targetIndex, diagnostics.writeRef(), code.writeRef())); - - auto linkage = getLinkage(); - - auto target = linkage->targets[targetIndex]; - - auto targetProgram = getTargetProgram(target); - - IArtifact* artifact = targetProgram->getExistingEntryPointResult(entryPointIndex); - - // Add diagnostics id needs be... - if (diagnostics && !_findDiagnosticRepresentation(artifact)) - { - // Add as an associated - - auto diagnosticsArtifact = Artifact::create( - ArtifactDesc::make(Artifact::Kind::HumanText, ArtifactPayload::Diagnostics)); - diagnosticsArtifact->addRepresentationUnknown(diagnostics); - - artifact->addAssociated(diagnosticsArtifact); - - SLANG_ASSERT(diagnosticsArtifact == _findDiagnosticRepresentation(artifact)); - } - - // Add obfuscated source maps - if (!_findObfuscatedSourceMap(artifact)) - { - List<IRModule*> irModules; - enumerateIRModules([&](IRModule* irModule) -> void { irModules.add(irModule); }); - - for (auto irModule : irModules) - { - if (auto obfuscatedSourceMap = irModule->getObfuscatedSourceMap()) - { - auto artifactDesc = ArtifactDesc::make( - ArtifactKind::Json, - ArtifactPayload::SourceMap, - ArtifactStyle::Obfuscated); - - // Create the source map artifact - auto sourceMapArtifact = Artifact::create( - artifactDesc, - obfuscatedSourceMap->get().m_file.getUnownedSlice()); - - sourceMapArtifact->addRepresentation(obfuscatedSourceMap); - - // associate with the artifact - artifact->addAssociated(sourceMapArtifact); - } - } - } - - // Turn into a file system and return - ComPtr<ISlangMutableFileSystem> fileSystem(new MemoryFileSystem); - - // Filter the containerArtifact into things that can be written - ComPtr<IArtifact> writeArtifact; - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(artifact, writeArtifact)); - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, "", fileSystem)); - - *outFileSystem = fileSystem.detach(); - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( - SlangInt entryPointIndex, - Int targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return SLANG_E_INVALID_ARG; - auto target = linkage->targets[targetIndex]; - - auto targetProgram = getTargetProgram(target); - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); - sink.getBlobIfNeeded(outDiagnostics); - - if (artifact == nullptr) - return SLANG_FAIL; - - return artifact->loadBlob(ArtifactKeep::Yes, outCode); -} - -SLANG_NO_THROW void SLANG_MCALL ComponentType::getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) -{ - DigestBuilder<SHA1> builder; - - // A note on enums that may be hashed in as part of the following two function calls: - // - // While enums are not guaranteed to be encoded the same way across all versions of - // the compiler, part of hashing the linkage is hashing in the compiler version. - // Consequently, any encoding differences as a result of different compiler versions - // will already be reflected in the resulting hash. - getLinkage()->buildHash(builder, targetIndex); - - buildHash(builder); - - // Add the name and name override for the specified entry point to the hash. - auto entryPoint = getEntryPoint(entryPointIndex); - if (entryPoint) - { - auto entryPointName = entryPoint->getName()->text; - builder.append(entryPointName); - auto entryPointMangledName = getEntryPointMangledName(entryPointIndex); - builder.append(entryPointMangledName); - auto entryPointNameOverride = getEntryPointNameOverride(entryPointIndex); - builder.append(entryPointNameOverride); - } - - auto hash = builder.finalize().toBlob(); - *outHash = hash.detach(); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return SLANG_E_INVALID_ARG; - auto target = linkage->targets[targetIndex]; - - auto targetProgram = getTargetProgram(target); - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); - sink.getBlobIfNeeded(outDiagnostics); - - if (artifact == nullptr) - return SLANG_FAIL; - - return artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointMetadata( - SlangInt entryPointIndex, - Int targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return SLANG_E_INVALID_ARG; - auto target = linkage->targets[targetIndex]; - - auto targetProgram = getTargetProgram(target); - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); - sink.getBlobIfNeeded(outDiagnostics); - - if (artifact == nullptr) - return SLANG_E_NOT_AVAILABLE; - - auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); - if (!metadata) - return SLANG_E_NOT_AVAILABLE; - - *outMetadata = static_cast<slang::IMetadata*>(metadata); - (*outMetadata)->addRef(); - return SLANG_OK; -} - -RefPtr<ComponentType> ComponentType::specialize( - SpecializationArg const* inSpecializationArgs, - SlangInt specializationArgCount, - DiagnosticSink* sink) -{ - if (specializationArgCount == 0) - { - return this; - } - - List<SpecializationArg> specializationArgs; - specializationArgs.addRange(inSpecializationArgs, specializationArgCount); - - // We next need to validate that the specialization arguments - // make sense, and also expand them to include any derived data - // (e.g., interface conformance witnesses) that doesn't get - // passed explicitly through the API interface. - // - RefPtr<SpecializationInfo> specializationInfo = - _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink); - - return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) -{ - DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - - // First let's check if the number of arguments given matches - // the number of parameters that are present on this component type. - // - auto specializationParamCount = getSpecializationParamCount(); - if (specializationArgCount != specializationParamCount) - { - sink.diagnose( - SourceLoc(), - Diagnostics::mismatchSpecializationArguments, - specializationParamCount, - specializationArgCount); - sink.getBlobIfNeeded(outDiagnostics); - return SLANG_FAIL; - } - - List<SpecializationArg> expandedArgs; - for (Int aa = 0; aa < specializationArgCount; ++aa) - { - auto apiArg = specializationArgs[aa]; - - SpecializationArg expandedArg; - switch (apiArg.kind) - { - case slang::SpecializationArg::Kind::Type: - expandedArg.val = asInternal(apiArg.type); - break; - - default: - sink.getBlobIfNeeded(outDiagnostics); - return SLANG_FAIL; - } - expandedArgs.add(expandedArg); - } - - auto specializedComponentType = - specialize(expandedArgs.getBuffer(), expandedArgs.getCount(), &sink); - - sink.getBlobIfNeeded(outDiagnostics); - - *outSpecializedComponentType = specializedComponentType.detach(); - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -ComponentType::renameEntryPoint(const char* newName, IComponentType** outEntryPoint) -{ - RefPtr<RenamedEntryPointComponentType> result = - new RenamedEntryPointComponentType(this, newName); - *outEntryPoint = result.detach(); - return SLANG_OK; -} - -RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType); - -SLANG_NO_THROW SlangResult SLANG_MCALL -ComponentType::link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) -{ - // TODO: It should be possible for `fillRequirements` to fail, - // in cases where we have a dependency that can't be automatically - // resolved. - // - SLANG_UNUSED(outDiagnostics); - - DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - - try - { - auto linked = fillRequirements(this); - if (!linked) - return SLANG_FAIL; - - *outLinkedComponentType = ComPtr<slang::IComponentType>(linked).detach(); - return SLANG_OK; - } - catch (const AbortCompilationException& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return SLANG_FAIL; - } - catch (const Exception& e) - { - outputExceptionDiagnostic(e, sink, outDiagnostics); - return SLANG_FAIL; - } - catch (...) - { - outputExceptionDiagnostic(sink, outDiagnostics); - return SLANG_FAIL; - } -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) -{ - SLANG_RETURN_ON_FAIL(link(outLinkedComponentType, outDiagnostics)); - - auto linked = *outLinkedComponentType; - - if (linked) - { - static_cast<ComponentType*>(linked)->getOptionSet().load(count, entries); - } - - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCompileResult( - SlangInt entryPointIndex, - Int targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return SLANG_E_INVALID_ARG; - auto target = linkage->targets[targetIndex]; - - auto targetProgram = getTargetProgram(target); - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - IArtifact* artifact = targetProgram->getOrCreateEntryPointResult(entryPointIndex, &sink); - sink.getBlobIfNeeded(outDiagnostics); - if (artifact == nullptr) - return SLANG_E_NOT_AVAILABLE; - - *outCompileResult = static_cast<slang::ICompileResult*>(artifact); - (*outCompileResult)->addRef(); - return SLANG_OK; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetCompileResult( - Int targetIndex, - slang::ICompileResult** outCompileResult, - slang::IBlob** outDiagnostics) -{ - IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); - if (artifact == nullptr) - return SLANG_E_NOT_AVAILABLE; - - *outCompileResult = static_cast<slang::ICompileResult*>(artifact); - (*outCompileResult)->addRef(); - return SLANG_OK; -} - -/// Visitor used by `ComponentType::enumerateModules` -struct EnumerateModulesVisitor : ComponentTypeVisitor -{ - EnumerateModulesVisitor(ComponentType::EnumerateModulesCallback callback, void* userData) - : m_callback(callback), m_userData(userData) - { - } - - ComponentType::EnumerateModulesCallback m_callback; - void* m_userData; - - void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} - - void visitRenamedEntryPoint( - RenamedEntryPointComponentType* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - entryPoint->getBase()->acceptVisitor(this, specializationInfo); - } - - void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE - { - m_callback(module, m_userData); - } - - void visitComposite( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - visitChildren(composite, specializationInfo); - } - - void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE - { - visitChildren(specialized); - } - - void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE - { - SLANG_UNUSED(conformance); - } -}; - - -void ComponentType::enumerateModules(EnumerateModulesCallback callback, void* userData) -{ - EnumerateModulesVisitor visitor(callback, userData); - acceptVisitor(&visitor, nullptr); -} - -/// Visitor used by `ComponentType::enumerateIRModules` -struct EnumerateIRModulesVisitor : ComponentTypeVisitor -{ - EnumerateIRModulesVisitor(ComponentType::EnumerateIRModulesCallback callback, void* userData) - : m_callback(callback), m_userData(userData) - { - } - - ComponentType::EnumerateIRModulesCallback m_callback; - void* m_userData; - - void visitEntryPoint(EntryPoint*, EntryPoint::EntryPointSpecializationInfo*) SLANG_OVERRIDE {} - - void visitRenamedEntryPoint( - RenamedEntryPointComponentType* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - entryPoint->getBase()->acceptVisitor(this, specializationInfo); - } - - void visitModule(Module* module, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE - { - m_callback(module->getIRModule(), m_userData); - } - - void visitComposite( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - visitChildren(composite, specializationInfo); - } - - void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE - { - visitChildren(specialized); - - m_callback(specialized->getIRModule(), m_userData); - } - - void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE - { - m_callback(conformance->getIRModule(), m_userData); - } -}; - -void ComponentType::enumerateIRModules(EnumerateIRModulesCallback callback, void* userData) -{ - EnumerateIRModulesVisitor visitor(callback, userData); - acceptVisitor(&visitor, nullptr); -} - -IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outDiagnostics) -{ - auto linkage = getLinkage(); - if (targetIndex < 0 || targetIndex >= linkage->targets.getCount()) - return nullptr; - ComPtr<IArtifact> artifact; - if (m_targetArtifacts.tryGetValue(targetIndex, artifact)) - { - return artifact.get(); - } - - // If the user hasn't specified any entry points, then we should - // discover all entrypoints that are defined in linked modules, and - // include all of them in the compile. - // - if (getEntryPointCount() == 0) - { - List<Module*> modules; - this->enumerateModules([&](Module* module) { modules.add(module); }); - List<RefPtr<ComponentType>> components; - components.add(this); - bool entryPointsDiscovered = false; - for (auto module : modules) - { - for (auto entryPoint : module->getEntryPoints()) - { - components.add(entryPoint); - entryPointsDiscovered = true; - } - } - - // If any entry points were discovered, then we should emit the program with entrypoints - // linked. - if (entryPointsDiscovered) - { - RefPtr<CompositeComponentType> composite = - new CompositeComponentType(linkage, components); - ComPtr<IComponentType> linkedComponentType; - SLANG_RETURN_NULL_ON_FAIL( - composite->link(linkedComponentType.writeRef(), outDiagnostics)); - auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get()) - ->getTargetArtifact(targetIndex, outDiagnostics); - if (targetArtifact) - { - m_targetArtifacts[targetIndex] = targetArtifact; - } - return targetArtifact; - } - } - - auto target = linkage->targets[targetIndex]; - auto targetProgram = getTargetProgram(target); - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - - IArtifact* targetArtifact = targetProgram->getOrCreateWholeProgramResult(&sink); - sink.getBlobIfNeeded(outDiagnostics); - m_targetArtifacts[targetIndex] = ComPtr<IArtifact>(targetArtifact); - return targetArtifact; -} - -SLANG_NO_THROW SlangResult SLANG_MCALL -ComponentType::getTargetCode(Int targetIndex, slang::IBlob** outCode, slang::IBlob** outDiagnostics) -{ - IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); - - if (artifact == nullptr) - return SLANG_FAIL; - - return artifact->loadBlob(ArtifactKeep::Yes, outCode); -} - -SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata( - Int targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) -{ - IArtifact* artifact = getTargetArtifact(targetIndex, outDiagnostics); - - if (artifact == nullptr) - return SLANG_FAIL; - - auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); - if (!metadata) - return SLANG_E_NOT_AVAILABLE; - *outMetadata = static_cast<slang::IMetadata*>(metadata); - (*outMetadata)->addRef(); - return SLANG_OK; -} - -// -// CompositeComponentType -// - -RefPtr<ComponentType> CompositeComponentType::create( - Linkage* linkage, - List<RefPtr<ComponentType>> const& childComponents) -{ - // TODO: We should ideally be caching the results of - // composition on the `linkage`, so that if we get - // asked for the same composite again later we re-use - // it rather than re-create it. - // - // Similarly, we might want to do some amount of - // work to "canonicalize" the input for composition. - // E.g., if the user does: - // - // X = compose(A,B); - // Y = compose(C,D); - // Z = compose(X,Y); - // - // W = compose(A, B, C, D); - // - // Then there is no observable difference between - // Z and W, so we might prefer to have them be identical. - - // If there is only a single child, then we should - // just return that child rather than create a dummy composite. - // - if (childComponents.getCount() == 1) - { - return childComponents[0]; - } - - return new CompositeComponentType(linkage, childComponents); -} - - -CompositeComponentType::CompositeComponentType( - Linkage* linkage, - List<RefPtr<ComponentType>> const& childComponents) - : ComponentType(linkage), m_childComponents(childComponents) -{ - HashSet<ComponentType*> requirementsSet; - for (auto child : childComponents) - { - child->enumerateModules([&](Module* module) { requirementsSet.add(module); }); - } - - for (auto child : childComponents) - { - auto childEntryPointCount = child->getEntryPointCount(); - for (Index cc = 0; cc < childEntryPointCount; ++cc) - { - m_entryPoints.add(child->getEntryPoint(cc)); - m_entryPointMangledNames.add(child->getEntryPointMangledName(cc)); - m_entryPointNameOverrides.add(child->getEntryPointNameOverride(cc)); - } - - auto childShaderParamCount = child->getShaderParamCount(); - for (Index pp = 0; pp < childShaderParamCount; ++pp) - { - m_shaderParams.add(child->getShaderParam(pp)); - } - - auto childSpecializationParamCount = child->getSpecializationParamCount(); - for (Index pp = 0; pp < childSpecializationParamCount; ++pp) - { - m_specializationParams.add(child->getSpecializationParam(pp)); - } - - for (auto module : child->getModuleDependencies()) - { - m_moduleDependencyList.addDependency(module); - } - for (auto sourceFile : child->getFileDependencies()) - { - m_fileDependencyList.addDependency(sourceFile); - } - - auto childRequirementCount = child->getRequirementCount(); - for (Index rr = 0; rr < childRequirementCount; ++rr) - { - auto childRequirement = child->getRequirement(rr); - if (!requirementsSet.contains(childRequirement)) - { - requirementsSet.add(childRequirement); - m_requirements.add(childRequirement); - } - } - } -} - -void CompositeComponentType::buildHash(DigestBuilder<SHA1>& builder) -{ - auto componentCount = getChildComponentCount(); - - for (Index i = 0; i < componentCount; ++i) - { - getChildComponent(i)->buildHash(builder); - } -} - -Index CompositeComponentType::getEntryPointCount() -{ - return m_entryPoints.getCount(); -} - -RefPtr<EntryPoint> CompositeComponentType::getEntryPoint(Index index) -{ - return m_entryPoints[index]; -} - -String CompositeComponentType::getEntryPointMangledName(Index index) -{ - return m_entryPointMangledNames[index]; -} - -String CompositeComponentType::getEntryPointNameOverride(Index index) -{ - return m_entryPointNameOverrides[index]; -} - -Index CompositeComponentType::getShaderParamCount() -{ - return m_shaderParams.getCount(); -} - -ShaderParamInfo CompositeComponentType::getShaderParam(Index index) -{ - return m_shaderParams[index]; -} - -Index CompositeComponentType::getSpecializationParamCount() -{ - return m_specializationParams.getCount(); -} - -SpecializationParam const& CompositeComponentType::getSpecializationParam(Index index) -{ - return m_specializationParams[index]; -} - -Index CompositeComponentType::getRequirementCount() -{ - return m_requirements.getCount(); -} - -RefPtr<ComponentType> CompositeComponentType::getRequirement(Index index) -{ - return m_requirements[index]; -} - -List<Module*> const& CompositeComponentType::getModuleDependencies() -{ - return m_moduleDependencyList.getModuleList(); -} - -List<SourceFile*> const& CompositeComponentType::getFileDependencies() -{ - return m_fileDependencyList.getFileList(); -} - -void CompositeComponentType::acceptVisitor( - ComponentTypeVisitor* visitor, - SpecializationInfo* specializationInfo) -{ - visitor->visitComposite(this, as<CompositeSpecializationInfo>(specializationInfo)); -} - -RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) -{ - SLANG_UNUSED(argCount); - - RefPtr<CompositeSpecializationInfo> specializationInfo = new CompositeSpecializationInfo(); - - Index offset = 0; - for (auto child : m_childComponents) - { - auto childParamCount = child->getSpecializationParamCount(); - SLANG_ASSERT(offset + childParamCount <= argCount); - - auto childInfo = child->_validateSpecializationArgs(args + offset, childParamCount, sink); - - specializationInfo->childInfos.add(childInfo); - - offset += childParamCount; - } - return specializationInfo; -} - -// -// SpecializedComponentType -// - -/// Utility type for collecting modules references by types/declarations -struct SpecializationArgModuleCollector : ComponentTypeVisitor -{ - HashSet<Module*> m_modulesSet; - List<Module*> m_modulesList; - - void addModule(Module* module) - { - m_modulesList.add(module); - m_modulesSet.add(module); - } - - void maybeAddModule(Module* module) - { - if (!module) - return; - if (m_modulesSet.contains(module)) - return; - - addModule(module); - } - - void collectReferencedModules(Decl* decl) - { - auto module = getModule(decl); - maybeAddModule(module); - } - - void collectReferencedModules(SubstitutionSet substitutions) - { - substitutions.forEachGenericSubstitution( - [this](GenericDecl*, Val::OperandView<Val> args) - { - for (auto arg : args) - { - collectReferencedModules(arg); - } - }); - } - - void collectReferencedModules(DeclRefBase* declRef) - { - collectReferencedModules(declRef->getDecl()); - collectReferencedModules(SubstitutionSet(declRef)); - } - - void collectReferencedModules(Type* type) - { - if (auto declRefType = as<DeclRefType>(type)) - { - collectReferencedModules(declRefType->getDeclRef()); - } - - // TODO: Handle non-decl-ref composite type cases - // (e.g., function types). - } - - void collectReferencedModules(Val* val) - { - if (auto type = as<Type>(val)) - { - collectReferencedModules(type); - } - else if (auto declRefVal = as<DeclRefIntVal>(val)) - { - collectReferencedModules(declRefVal->getDeclRef()); - } - - // TODO: other cases of values that could reference - // a declaration. - } - - void collectReferencedModules(List<ExpandedSpecializationArg> const& args) - { - for (auto arg : args) - { - collectReferencedModules(arg.val); - collectReferencedModules(arg.witness); - } - } - - // - // ComponentTypeVisitor methods - // - - void visitEntryPoint( - EntryPoint* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - SLANG_UNUSED(entryPoint); - - if (!specializationInfo) - return; - - collectReferencedModules(specializationInfo->specializedFuncDeclRef); - collectReferencedModules(specializationInfo->existentialSpecializationArgs); - } - - void visitRenamedEntryPoint( - RenamedEntryPointComponentType* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - entryPoint->getBase()->acceptVisitor(this, specializationInfo); - } - - void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) - SLANG_OVERRIDE - { - SLANG_UNUSED(module); - - if (!specializationInfo) - return; - - for (auto arg : specializationInfo->genericArgs) - { - collectReferencedModules(arg.argVal); - } - collectReferencedModules(specializationInfo->existentialArgs); - } - - void visitComposite( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - visitChildren(composite, specializationInfo); - } - - void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE - { - visitChildren(specialized); - } - - void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE - { - SLANG_UNUSED(conformance); - } -}; - -SpecializedComponentType::SpecializedComponentType( - ComponentType* base, - ComponentType::SpecializationInfo* specializationInfo, - List<SpecializationArg> const& specializationArgs, - DiagnosticSink* sink) - : ComponentType(base->getLinkage()) - , m_base(base) - , m_specializationInfo(specializationInfo) - , m_specializationArgs(specializationArgs) -{ - m_optionSet.overrideWith(base->getOptionSet()); - - m_irModule = generateIRForSpecializedComponentType(this, sink); - - // We need to account for the fact that a specialized - // entity like `myShader<SomeType>` needs to not only - // depend on the module(s) that `myShader` depends on, - // but also on any modules that `SomeType` depends on. - // - // We will set up a "collector" type that will be - // used to build a list of these additional modules. - // - SpecializationArgModuleCollector moduleCollector; - - // We don't want to go adding additional requirements for - // modules that the base component type already includes, - // so we will add those to the set of modules in - // the collector before we starting trying to add others. - // - base->enumerateModules([&](Module* module) { moduleCollector.m_modulesSet.add(module); }); - - // In order to collect the additional modules, we need - // to inspect the specialization arguments and see what - // they depend on. - // - // Naively, it seems like we'd just want to iterate - // over `specializationArgs`, which gives the specialization - // arguments as the user supplied them. However, such - // an approach would have a subtle problem. - // - // If we have a generic entry point like: - // - // // In module A - // myShader<T : IThing> - // - // - // And the type `SomeType` that is being used as an argument doesn't - // directly conform to `IThing`: - // - // // In module B - // struct SomeType { ... } - // - // and the conformance of `SomeType` to `IThing` is - // coming from yet another module: - // - // // In module C - // import B; - // extension SomeType : IThing { ... } - // - // In this case, the specialized component for `myShader<SomeType>` - // needs to depend on all of: - // - // * Module A, because it defines `myShader` - // * Module B, because it defines `SomeType` - // * Module C, because it defines the conformance `SomeType : IThing` - // - // We thus need to iterate over a form of the specialization - // arguments that includes the "expanded" arguments like - // interface conformance witnesses that got added during - // semantic checking. - // - // The expanded arguments are being stored in the `specializationInfo` - // today (for use by downstream code generation), and the easiest - // way to walk that information and get to the leaf nodes where - // the expanded arguments are stored is to apply a visitor to - // the specialized component type we are in the middle of constructing. - // - moduleCollector.visitSpecialized(this); - - // Now that we've collected our additional information, we can - // start to build up the final lists for the specialized component type. - // - // The starting point for our lists comes from the base component type. - // - m_moduleDependencies = base->getModuleDependencies(); - m_fileDependencies = base->getFileDependencies(); - - Index baseRequirementCount = base->getRequirementCount(); - for (Index r = 0; r < baseRequirementCount; r++) - { - m_requirements.add(base->getRequirement(r)); - } - - // The specialized component type will need to have additional - // dependencies and requirements based on the modules that - // were collected when looking at the specialization arguments. - - // We want to avoid adding the same file dependency more than once. - // - HashSet<SourceFile*> fileDependencySet; - for (SourceFile* sourceFile : m_fileDependencies) - fileDependencySet.add(sourceFile); - - for (auto module : moduleCollector.m_modulesList) - { - // The specialized component type will have an open (unsatisfied) - // requirement for each of the modules that its specialization - // arguments need. - // - // Note: what this means in practice is that the component type - // records that the given module(s) will need to be linked in - // before final code can be generated, but it importantly - // does not dictate the final placement of the parameters from - // those modules in the layout. - // - m_requirements.add(module); - - // The speciialized component type will also have a dependency - // on all the files that any of the modules involved in - // it depend on (including those that are required but not - // yet linked in). - // - // The file path information is what a client would need to - // use to decide if kernel code is out of date compared to - // source files, so we want to include anything that could - // affect the validity of generated code. - // - for (SourceFile* sourceFile : module->getFileDependencies()) - { - if (fileDependencySet.contains(sourceFile)) - continue; - fileDependencySet.add(sourceFile); - m_fileDependencies.add(sourceFile); - } - - // Finalyl we also add the module for the specialization arguments - // to the list of modules that would be used for legacy lookup - // operations where we need an implicit/default scope to use - // and want it to be expansive. - // - // TODO: This stuff really isn't worth keeping around long - // term, and we should ditch the entire "legacy lookup" idea. - // - m_moduleDependencies.add(module); - } - - // Because we are specializing shader code, the mangled entry - // point names for this component type may be different than - // for the base component type (e.g., the mangled name for `f<int>` - // is different than that that of the generic `f` function - // itself). - // - // We will compute the mangled names of all the entry points and - // store them here, so that we don't have to do it on the fly. - // Because the `ComponentType` structure is hierarchical, we - // need to use a recursive visitor to compute the names, - // and we will define that visitor locally: - // - struct EntryPointMangledNameCollector : ComponentTypeVisitor - { - List<String>* mangledEntryPointNames; - List<String>* entryPointNameOverrides; - - void visitEntryPoint( - EntryPoint* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - auto funcDeclRef = entryPoint->getFuncDeclRef(); - if (specializationInfo) - funcDeclRef = specializationInfo->specializedFuncDeclRef; - - (*mangledEntryPointNames).add(getMangledName(m_astBuilder, funcDeclRef)); - (*entryPointNameOverrides).add(entryPoint->getEntryPointNameOverride(0)); - } - - void visitRenamedEntryPoint( - RenamedEntryPointComponentType* entryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - entryPoint->getBase()->acceptVisitor(this, specializationInfo); - (*entryPointNameOverrides).getLast() = entryPoint->getEntryPointNameOverride(0); - } - - void visitModule(Module*, Module::ModuleSpecializationInfo*) SLANG_OVERRIDE {} - void visitComposite( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) SLANG_OVERRIDE - { - visitChildren(composite, specializationInfo); - } - void visitSpecialized(SpecializedComponentType* specialized) SLANG_OVERRIDE - { - visitChildren(specialized); - } - void visitTypeConformance(TypeConformance* conformance) SLANG_OVERRIDE - { - SLANG_UNUSED(conformance); - } - EntryPointMangledNameCollector(ASTBuilder* astBuilder) - : m_astBuilder(astBuilder) - { - } - ASTBuilder* m_astBuilder; - }; - - // With the visitor defined, we apply it to ourself to compute - // and collect the mangled entry point names. - // - EntryPointMangledNameCollector collector(getLinkage()->getASTBuilder()); - collector.mangledEntryPointNames = &m_entryPointMangledNames; - collector.entryPointNameOverrides = &m_entryPointNameOverrides; - collector.visitSpecialized(this); -} - -void SpecializedComponentType::buildHash(DigestBuilder<SHA1>& builder) -{ - auto specializationArgCount = getSpecializationArgCount(); - for (Index i = 0; i < specializationArgCount; ++i) - { - auto specializationArg = getSpecializationArg(i); - auto argString = specializationArg.val->toString(); - builder.append(argString); - } - - getBaseComponentType()->buildHash(builder); -} - -void SpecializedComponentType::acceptVisitor( - ComponentTypeVisitor* visitor, - SpecializationInfo* specializationInfo) -{ - SLANG_ASSERT(specializationInfo == nullptr); - SLANG_UNUSED(specializationInfo); - visitor->visitSpecialized(this); -} - -Index SpecializedComponentType::getRequirementCount() -{ - return m_requirements.getCount(); -} - -RefPtr<ComponentType> SpecializedComponentType::getRequirement(Index index) -{ - return m_requirements[index]; -} - -String SpecializedComponentType::getEntryPointMangledName(Index index) -{ - return m_entryPointMangledNames[index]; -} - -String SpecializedComponentType::getEntryPointNameOverride(Index index) -{ - return m_entryPointNameOverrides[index]; -} - -// RenamedEntryPointComponentType - -RenamedEntryPointComponentType::RenamedEntryPointComponentType(ComponentType* base, String newName) - : ComponentType(base->getLinkage()), m_base(base), m_entryPointNameOverride(newName) -{ -} - -void RenamedEntryPointComponentType::acceptVisitor( - ComponentTypeVisitor* visitor, - SpecializationInfo* specializationInfo) -{ - visitor->visitRenamedEntryPoint( - this, - as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo)); -} - -void RenamedEntryPointComponentType::buildHash(DigestBuilder<SHA1>& builder) -{ - SLANG_UNUSED(builder); -} - -void ComponentTypeVisitor::visitChildren( - CompositeComponentType* composite, - CompositeComponentType::CompositeSpecializationInfo* specializationInfo) -{ - auto childCount = composite->getChildComponentCount(); - for (Index ii = 0; ii < childCount; ++ii) - { - auto child = composite->getChildComponent(ii); - auto childSpecializationInfo = - specializationInfo ? specializationInfo->childInfos[ii] : nullptr; - - child->acceptVisitor(this, childSpecializationInfo); - } -} - -void ComponentTypeVisitor::visitChildren(SpecializedComponentType* specialized) -{ - specialized->getBaseComponentType()->acceptVisitor(this, specialized->getSpecializationInfo()); -} - -TargetProgram* ComponentType::getTargetProgram(TargetRequest* target) -{ - RefPtr<TargetProgram> targetProgram; - if (!m_targetPrograms.tryGetValue(target, targetProgram)) - { - targetProgram = new TargetProgram(this, target); - m_targetPrograms[target] = targetProgram; - } - return targetProgram; -} - -// -// TargetProgram -// - -TargetProgram::TargetProgram(ComponentType* componentType, TargetRequest* targetReq) - : m_program(componentType), m_targetReq(targetReq) -{ - m_entryPointResults.setCount(componentType->getEntryPointCount()); - m_optionSet.overrideWith(m_program->getOptionSet()); - m_optionSet.inheritFrom(targetReq->getOptionSet()); -} - -// - -Session* CompileRequestBase::getSession() -{ - return getLinkage()->getSessionImpl(); -} - -void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) -{ - // Set the fileSystem - m_fileSystem = inFileSystem; - - // Release what's there - m_fileSystemExt.setNull(); - - // If nullptr passed in set up default - if (inFileSystem == nullptr) - { - m_fileSystemExt = new Slang::CacheFileSystem(Slang::OSFileSystem::getExtSingleton()); - } - else - { - if (auto cacheFileSystem = as<CacheFileSystem>(inFileSystem)) - { - m_fileSystemExt = cacheFileSystem; - } - else - { - if (m_requireCacheFileSystem) - { - m_fileSystemExt = new Slang::CacheFileSystem(inFileSystem); - } - else - { - // See if we have the full ISlangFileSystemExt interface, if we do just use it - inFileSystem->queryInterface(SLANG_IID_PPV_ARGS(m_fileSystemExt.writeRef())); - - // If not wrap with CacheFileSystem that emulates ISlangFileSystemExt from the - // ISlangFileSystem interface - if (!m_fileSystemExt) - { - // Construct a wrapper to emulate the extended interface behavior - m_fileSystemExt = new Slang::CacheFileSystem(m_fileSystem); - } - } - } - } - - // If requires a cache file system, check that it does have one - SLANG_ASSERT(m_requireCacheFileSystem == false || as<CacheFileSystem>(m_fileSystemExt)); - - // Set the file system used on the source manager - getSourceManager()->setFileSystemExt(m_fileSystemExt); -} - -SlangResult Linkage::loadSerializedModuleContents( - Module* module, - const PathInfo& moduleFilePathInfo, - ISlangBlob* blobHoldingSerializedData, - ModuleChunk const* moduleChunk, - RIFF::ListChunk const* containerChunk, - DiagnosticSink* sink) -{ - // At this point we've dealt with basically all of - // the formalities, and we just need to get down - // to the real work of decoding the information - // in the `moduleChunk`. - - // - // TODO(tfoley): The fact that a separate `containerChunk` is getting - // passed in here is entirely byproduct of the support for "module libraries" - // that can (in principle) contain multiple serialized modules. When - // things are serialized in the "container" representation used for - // a module library, there is a single `DebugChunk` as a child of - // the container, with all of the `ModuleChunk`s sharing that debug info. - // - // In contrast, the more typical kind of serialized module that the compiler - // produces serializes a single `ModuleChunk`, and the `DebugChunk` is - // one of its direct children. Thus there are currently two different - // locations where debug information might be found. - // - // Prior to the change where we navigate the serialized RIFF hierarchy - // in memory without copying it, this issue was addressed by having - // the subroutine that looked for a `DebugChunk` start at the `ModuleChunk` - // and work its way up through the hierarchy using parent pointers that - // were created as part of RIFF loading. When navigating the RIFF in-place - // we don't have such parent pointers. - // - // As a short-term solution, we should deprecate and remove the support - // for "module libraries" so that the code doesn't have to handle two - // different layouts. - // - // In the longer term, we should be making some conscious design decisions - // around how we want to organize the top-level structure of our serialized - // intermediate/output formats, since there's quite a mix of different - // approaches currently in use. - // - - auto sourceManager = getSourceManager(); - RefPtr<SerialSourceLocReader> sourceLocReader; - if (auto debugChunk = DebugChunk::find(moduleChunk, containerChunk)) - { - SLANG_RETURN_ON_FAIL( - readSourceLocationsFromDebugChunk(debugChunk, sourceManager, sourceLocReader)); - } - - auto astChunk = moduleChunk->findAST(); - if (!astChunk) - return SLANG_FAIL; - - auto irChunk = moduleChunk->findIR(); - if (!irChunk) - return SLANG_FAIL; - - auto astBuilder = getASTBuilder(); - auto session = getSessionImpl(); - - // For the purposes of any modules referenced - // by the module we're about to decode, we will - // construct a source location that represents - // the module itself (if possible). - // - // TODO(tfoley): This logic seems like overkill, given - // that many (most? all?) control-flow paths that can - // reach this routine will have already found a `SourceFile` - // to represent the module, as part of even getting the - // `moduleFilePathInfo` to pass in - // - // The approach here is more or less exactly copied - // from what the old `SerialContainerUtil::read` function - // used to do, with the hopes that it will as many tests - // passing as possible. - // - // Down the line somebody should scrutinize all of this - // kind of logic in the compiler codebase, because there - // is something that feels unclean about how paths are being handled. - // - SourceLoc serializedModuleLoc; - { - auto sourceFile = - sourceManager->findSourceFileByPathRecursively(moduleFilePathInfo.foundPath); - if (!sourceFile) - { - sourceFile = sourceManager->createSourceFileWithString(moduleFilePathInfo, String()); - sourceManager->addSourceFile(moduleFilePathInfo.getMostUniqueIdentity(), sourceFile); - } - auto sourceView = - sourceManager->createSourceView(sourceFile, &moduleFilePathInfo, SourceLoc()); - serializedModuleLoc = sourceView->getRange().begin; - } - - auto moduleDecl = readSerializedModuleAST( - this, - astBuilder, - sink, - blobHoldingSerializedData, - astChunk, - sourceLocReader, - serializedModuleLoc); - if (!moduleDecl) - return SLANG_FAIL; - module->setModuleDecl(moduleDecl); - - RefPtr<IRModule> irModule; - SLANG_RETURN_ON_FAIL(readSerializedModuleIR(irChunk, session, sourceLocReader, irModule)); - module->setIRModule(irModule); - - // The handling of file dependencies is complicated, because of - // the way that the encoding logic tried to make all of the - // paths be relative to the primary source file for the module. - // - // We end up needing to undo some amount of that work here. - // - - module->clearFileDependency(); - String moduleSourcePath = moduleFilePathInfo.foundPath; - bool isFirst = true; - for (auto depenencyFileChunk : moduleChunk->getFileDependencies()) - { - auto encodedDependencyFilePath = depenencyFileChunk->getValue(); - - auto sourceFile = loadSourceFile(moduleFilePathInfo.foundPath, encodedDependencyFilePath); - if (isFirst) - { - // The first file is the source for the main module file. - // We store the module path as the basis for finding the remaining - // dependent files. - if (sourceFile) - moduleSourcePath = sourceFile->getPathInfo().foundPath; - isFirst = false; - } - // If we cannot find the dependent file directly, try to find - // it relative to the module source path. - if (!sourceFile) - { - sourceFile = loadSourceFile(moduleSourcePath, encodedDependencyFilePath); - } - if (sourceFile) - { - module->addFileDependency(sourceFile); - } - } - module->setPathInfo(moduleFilePathInfo); - module->setDigest(moduleChunk->getDigest()); - module->_collectShaderParams(); - module->_discoverEntryPoints(sink, targets); - - // Hook up fileDecl's scope to module's scope. - for (auto fileDecl : moduleDecl->getDirectMemberDeclsOfType<FileDecl>()) - { - addSiblingScopeForContainerDecl(m_astBuilder, moduleDecl->ownedScope, fileDecl); - } - - return SLANG_OK; -} - -void Linkage::setRequireCacheFileSystem(bool requireCacheFileSystem) -{ - if (requireCacheFileSystem == m_requireCacheFileSystem) - { - return; - } - - ComPtr<ISlangFileSystem> scopeFileSystem(m_fileSystem); - m_requireCacheFileSystem = requireCacheFileSystem; - - setFileSystem(scopeFileSystem); -} - -RefPtr<Module> findOrImportModule( - Linkage* linkage, - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* loadedModules) -{ - return linkage->findOrImportModule(name, loc, sink, loadedModules); -} - -void Session::addBuiltinSource( - Scope* scope, - String const& path, - ISlangBlob* sourceBlob, - Module*& outModule) -{ - SourceManager* sourceManager = getBuiltinSourceManager(); - - DiagnosticSink sink(sourceManager, Lexer::sourceLocationLexer); - - RefPtr<FrontEndCompileRequest> compileRequest = - new FrontEndCompileRequest(m_builtinLinkage, nullptr, &sink); - compileRequest->m_isCoreModuleCode = true; - - // Set the source manager on the sink - sink.setSourceManager(sourceManager); - // Make the linkage use the builtin source manager - Linkage* linkage = compileRequest->getLinkage(); - linkage->setSourceManager(sourceManager); - - Name* moduleName = getNamePool()->getName(path); - auto translationUnitIndex = - compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); - - compileRequest->addTranslationUnitSourceBlob(translationUnitIndex, path, sourceBlob); - - SlangResult res = compileRequest->executeActionsInner(); - if (SLANG_FAILED(res)) - { - char const* diagnostics = sink.outputBuffer.getBuffer(); - fprintf(stderr, "%s", diagnostics); - - PlatformUtil::outputDebugMessage(diagnostics); - - SLANG_UNEXPECTED("error in Slang core module"); - } - - // Compiling the core module should not yield any warnings. - SLANG_ASSERT(sink.outputBuffer.getLength() == 0); - - // Extract the AST for the code we just parsed - auto module = compileRequest->translationUnits[translationUnitIndex]->getModule(); - auto moduleDecl = module->getModuleDecl(); - - // Extact documentation markup. - ASTMarkup markup; - ASTMarkupUtil::extract(moduleDecl, sourceManager, &sink, &markup); - markup.attachToAST(); - - // Put in the loaded module map - linkage->mapNameToLoadedModules.add(moduleName, module); - - // Add the resulting code to the appropriate scope - if (!scope->containerDecl) - { - // We are the first chunk of code to be loaded for this scope - scope->containerDecl = moduleDecl; - } - else - { - // We need to create a new scope to link into the whole thing - auto subScope = module->getASTBuilder()->create<Scope>(); - subScope->containerDecl = moduleDecl; - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - - outModule = module; -} - -Session::~Session() -{ - // This is necessary because this ASTBuilder uses the SharedASTBuilder also owned by the - // session. If the SharedASTBuilder gets dtored before the globalASTBuilder it has a dangling - // pointer, which is referenced in the ASTBuilder dtor (likely) causing a crash. - // - // By destroying first we know it is destroyed, before the SharedASTBuilder. - globalAstBuilder.setNull(); - - // destroy modules next - coreModules = decltype(coreModules)(); -} - -} // namespace Slang - - -/* !!!!!!!!!!!!!!!!!! EndToEndCompileRequestImpl !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -namespace Slang -{ - -void EndToEndCompileRequest::setFileSystem(ISlangFileSystem* fileSystem) -{ - getLinkage()->setFileSystem(fileSystem); -} - -void EndToEndCompileRequest::setCompileFlags(SlangCompileFlags flags) -{ - if (flags & SLANG_COMPILE_FLAG_NO_MANGLING) - getOptionSet().set(CompilerOptionName::NoMangle, true); - if (flags & SLANG_COMPILE_FLAG_NO_CODEGEN) - getOptionSet().set(CompilerOptionName::SkipCodeGen, true); - if (flags & SLANG_COMPILE_FLAG_OBFUSCATE) - getOptionSet().set(CompilerOptionName::Obfuscate, true); -} - -SlangCompileFlags EndToEndCompileRequest::getCompileFlags() -{ - SlangCompileFlags result = 0; - if (getOptionSet().getBoolOption(CompilerOptionName::NoMangle)) - result |= SLANG_COMPILE_FLAG_NO_MANGLING; - if (getOptionSet().getBoolOption(CompilerOptionName::SkipCodeGen)) - result |= SLANG_COMPILE_FLAG_NO_CODEGEN; - if (getOptionSet().getBoolOption(CompilerOptionName::Obfuscate)) - result |= SLANG_COMPILE_FLAG_OBFUSCATE; - return result; -} - -void EndToEndCompileRequest::setDumpIntermediates(int enable) -{ - getOptionSet().set(CompilerOptionName::DumpIntermediates, enable); -} - -void EndToEndCompileRequest::setTrackLiveness(bool v) -{ - getOptionSet().set(CompilerOptionName::TrackLiveness, v); -} - -void EndToEndCompileRequest::setDumpIntermediatePrefix(const char* prefix) -{ - getOptionSet().set(CompilerOptionName::DumpIntermediatePrefix, String(prefix)); -} - -void EndToEndCompileRequest::setLineDirectiveMode(SlangLineDirectiveMode mode) -{ - getOptionSet().set(CompilerOptionName::LineDirectiveMode, mode); -} - -void EndToEndCompileRequest::setCommandLineCompilerMode() -{ - m_isCommandLineCompile = true; - - // legacy slangc tool defaults to column major layout. - if (!getOptionSet().hasOption(CompilerOptionName::MatrixLayoutRow)) - getOptionSet().setMatrixLayoutMode(kMatrixLayoutMode_ColumnMajor); -} - -void EndToEndCompileRequest::_completeTargetRequest(UInt targetIndex) -{ - auto linkage = getLinkage(); - - TargetRequest* targetRequest = linkage->targets[Index(targetIndex)]; - - targetRequest->getOptionSet().inheritFrom(getLinkage()->m_optionSet); - targetRequest->getOptionSet().inheritFrom(m_optionSetForDefaultTarget); -} - -void EndToEndCompileRequest::setCodeGenTarget(SlangCompileTarget target) -{ - auto linkage = getLinkage(); - linkage->targets.clear(); - const auto targetIndex = linkage->addTarget(CodeGenTarget(target)); - SLANG_ASSERT(targetIndex == 0); - _completeTargetRequest(0); -} - -int EndToEndCompileRequest::addCodeGenTarget(SlangCompileTarget target) -{ - const auto targetIndex = getLinkage()->addTarget(CodeGenTarget(target)); - _completeTargetRequest(targetIndex); - return int(targetIndex); -} - -void EndToEndCompileRequest::setTargetProfile(int targetIndex, SlangProfileID profile) -{ - getTargetOptionSet(targetIndex).setProfile(Profile(profile)); -} - -void EndToEndCompileRequest::setTargetFlags(int targetIndex, SlangTargetFlags flags) -{ - getTargetOptionSet(targetIndex).setTargetFlags(flags); -} - -void EndToEndCompileRequest::setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) -{ - getTargetOptionSet(targetIndex).set(CompilerOptionName::GLSLForceScalarLayout, value); -} - -void EndToEndCompileRequest::setTargetForceDXLayout(int targetIndex, bool value) -{ - getTargetOptionSet(targetIndex).set(CompilerOptionName::ForceDXLayout, value); -} - -void EndToEndCompileRequest::setTargetFloatingPointMode( - int targetIndex, - SlangFloatingPointMode mode) -{ - getTargetOptionSet(targetIndex) - .set(CompilerOptionName::FloatingPointMode, FloatingPointMode(mode)); -} - -void EndToEndCompileRequest::setMatrixLayoutMode(SlangMatrixLayoutMode mode) -{ - getOptionSet().setMatrixLayoutMode((MatrixLayoutMode)mode); -} - -void EndToEndCompileRequest::setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) -{ - getTargetOptionSet(targetIndex).setMatrixLayoutMode(MatrixLayoutMode(mode)); -} - -void EndToEndCompileRequest::setTargetGenerateWholeProgram(int targetIndex, bool value) -{ - getTargetOptionSet(targetIndex).set(CompilerOptionName::GenerateWholeProgram, value); -} - -void EndToEndCompileRequest::setTargetEmbedDownstreamIR(int targetIndex, bool value) -{ - getTargetOptionSet(targetIndex).set(CompilerOptionName::EmbedDownstreamIR, value); -} - -void EndToEndCompileRequest::setTargetLineDirectiveMode( - SlangInt targetIndex, - SlangLineDirectiveMode mode) -{ - getTargetOptionSet(targetIndex) - .set(CompilerOptionName::LineDirectiveMode, LineDirectiveMode(mode)); -} - -void EndToEndCompileRequest::overrideDiagnosticSeverity( - SlangInt messageID, - SlangSeverity overrideSeverity) -{ - getSink()->overrideDiagnosticSeverity(int(messageID), Severity(overrideSeverity)); -} - -SlangDiagnosticFlags EndToEndCompileRequest::getDiagnosticFlags() -{ - DiagnosticSink::Flags sinkFlags = getSink()->getFlags(); - - SlangDiagnosticFlags flags = 0; - - if (sinkFlags & DiagnosticSink::Flag::VerbosePath) - flags |= SLANG_DIAGNOSTIC_FLAG_VERBOSE_PATHS; - - if (sinkFlags & DiagnosticSink::Flag::TreatWarningsAsErrors) - flags |= SLANG_DIAGNOSTIC_FLAG_TREAT_WARNINGS_AS_ERRORS; - - return flags; -} - -void EndToEndCompileRequest::setDiagnosticFlags(SlangDiagnosticFlags flags) -{ - DiagnosticSink::Flags sinkFlags = getSink()->getFlags(); - - if (flags & SLANG_DIAGNOSTIC_FLAG_VERBOSE_PATHS) - sinkFlags |= DiagnosticSink::Flag::VerbosePath; - else - sinkFlags &= ~DiagnosticSink::Flag::VerbosePath; - - if (flags & SLANG_DIAGNOSTIC_FLAG_TREAT_WARNINGS_AS_ERRORS) - sinkFlags |= DiagnosticSink::Flag::TreatWarningsAsErrors; - else - sinkFlags &= ~DiagnosticSink::Flag::TreatWarningsAsErrors; - - getSink()->setFlags(sinkFlags); -} - -SlangResult EndToEndCompileRequest::addTargetCapability( - SlangInt targetIndex, - SlangCapabilityID capability) -{ - auto& targets = getLinkage()->targets; - if (targetIndex < 0 || targetIndex >= targets.getCount()) - return SLANG_E_INVALID_ARG; - getTargetOptionSet(targetIndex).addCapabilityAtom(CapabilityName(capability)); - return SLANG_OK; -} - -void EndToEndCompileRequest::setDebugInfoLevel(SlangDebugInfoLevel level) -{ - getOptionSet().set(CompilerOptionName::DebugInformation, DebugInfoLevel(level)); -} - -void EndToEndCompileRequest::setDebugInfoFormat(SlangDebugInfoFormat format) -{ - getOptionSet().set(CompilerOptionName::DebugInformationFormat, DebugInfoFormat(format)); -} - -void EndToEndCompileRequest::setOptimizationLevel(SlangOptimizationLevel level) -{ - getOptionSet().set(CompilerOptionName::Optimization, OptimizationLevel(level)); -} - -void EndToEndCompileRequest::setOutputContainerFormat(SlangContainerFormat format) -{ - m_containerFormat = ContainerFormat(format); -} - -void EndToEndCompileRequest::setPassThrough(SlangPassThrough inPassThrough) -{ - m_passThrough = PassThroughMode(inPassThrough); -} - -void EndToEndCompileRequest::setReportDownstreamTime(bool value) -{ - getOptionSet().set(CompilerOptionName::ReportDownstreamTime, value); -} - -void EndToEndCompileRequest::setReportPerfBenchmark(bool value) -{ - getOptionSet().set(CompilerOptionName::ReportPerfBenchmark, value); -} - -void EndToEndCompileRequest::setSkipSPIRVValidation(bool value) -{ - getOptionSet().set(CompilerOptionName::SkipSPIRVValidation, value); -} - -void EndToEndCompileRequest::setTargetUseMinimumSlangOptimization(int targetIndex, bool value) -{ - getTargetOptionSet(targetIndex).set(CompilerOptionName::MinimumSlangOptimization, value); -} - -void EndToEndCompileRequest::setIgnoreCapabilityCheck(bool value) -{ - getOptionSet().set(CompilerOptionName::IgnoreCapabilities, value); -} - -void EndToEndCompileRequest::setDiagnosticCallback( - SlangDiagnosticCallback callback, - void const* userData) -{ - ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole)); - setWriter(WriterChannel::Diagnostic, writer); -} - -void EndToEndCompileRequest::setWriter(SlangWriterChannel chan, ISlangWriter* writer) -{ - setWriter(WriterChannel(chan), writer); -} - -ISlangWriter* EndToEndCompileRequest::getWriter(SlangWriterChannel chan) -{ - return getWriter(WriterChannel(chan)); -} - -void EndToEndCompileRequest::addSearchPath(const char* path) -{ - getOptionSet().addSearchPath(path); -} - -void EndToEndCompileRequest::addPreprocessorDefine(const char* key, const char* value) -{ - getOptionSet().addPreprocessorDefine(key, value); -} - -void EndToEndCompileRequest::setEnableEffectAnnotations(bool value) -{ - getOptionSet().set(CompilerOptionName::EnableEffectAnnotations, value); -} - -char const* EndToEndCompileRequest::getDiagnosticOutput() -{ - return m_diagnosticOutput.begin(); -} - -SlangResult EndToEndCompileRequest::getDiagnosticOutputBlob(ISlangBlob** outBlob) -{ - if (!outBlob) - return SLANG_E_INVALID_ARG; - - if (!m_diagnosticOutputBlob) - { - m_diagnosticOutputBlob = StringUtil::createStringBlob(m_diagnosticOutput); - } - - ComPtr<ISlangBlob> resultBlob = m_diagnosticOutputBlob; - *outBlob = resultBlob.detach(); - return SLANG_OK; -} - -int EndToEndCompileRequest::addTranslationUnit(SlangSourceLanguage language, char const* inName) -{ - auto frontEndReq = getFrontEndReq(); - NamePool* namePool = frontEndReq->getNamePool(); - - // Work out a module name. Can be nullptr if so will generate a name - Name* moduleName = inName ? namePool->getName(inName) : frontEndReq->m_defaultModuleName; - - // If moduleName is nullptr a name will be generated - return frontEndReq->addTranslationUnit(Slang::SourceLanguage(language), moduleName); -} - -void EndToEndCompileRequest::setDefaultModuleName(const char* defaultModuleName) -{ - auto frontEndReq = getFrontEndReq(); - NamePool* namePool = frontEndReq->getNamePool(); - frontEndReq->m_defaultModuleName = namePool->getName(defaultModuleName); -} - -SlangResult _addLibraryReference( - EndToEndCompileRequest* req, - ModuleLibrary* moduleLibrary, - bool includeEntryPoint) -{ - FrontEndCompileRequest* frontEndRequest = req->getFrontEndReq(); - - if (includeEntryPoint) - { - frontEndRequest->m_extraEntryPoints.addRange( - moduleLibrary->m_entryPoints.getBuffer(), - moduleLibrary->m_entryPoints.getCount()); - } - - for (auto m : moduleLibrary->m_modules) - { - RefPtr<TranslationUnitRequest> tu = new TranslationUnitRequest(frontEndRequest, m); - frontEndRequest->translationUnits.add(tu); - // For modules loaded for EndToEndCompileRequest, - // we don't need the automatically discovered entrypoints. - if (!includeEntryPoint) - m->getEntryPoints().clear(); - } - return SLANG_OK; -} - -SlangResult _addLibraryReference( - EndToEndCompileRequest* req, - String path, - IArtifact* artifact, - bool includeEntryPoint) -{ - auto desc = artifact->getDesc(); - - // TODO(JS): - // This isn't perhaps the best way to handle this scenario, as IArtifact can - // support lazy evaluation, with suitable hander. - // For now we just read in and strip out the bits we want. - if (isDerivedFrom(desc.kind, ArtifactKind::Container) && - isDerivedFrom(desc.payload, ArtifactPayload::CompileResults)) - { - // We want to read as a file system - ComPtr<IArtifact> container; - - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::readContainer(artifact, container)); - - // Find the payload... It should be linkable - if (!ArtifactDescUtil::isLinkable(container->getDesc())) - { - return SLANG_FAIL; - } - - ComPtr<IModuleLibrary> libraryIntf; - SLANG_RETURN_ON_FAIL( - loadModuleLibrary(ArtifactKeep::Yes, container, path, req, libraryIntf)); - - auto library = as<ModuleLibrary>(libraryIntf); - - // Look for source maps - for (auto associated : container->getAssociated()) - { - auto assocDesc = associated->getDesc(); - - // If we find an obfuscated source map load it and associate - if (isDerivedFrom(assocDesc.kind, ArtifactKind::Json) && - isDerivedFrom(assocDesc.payload, ArtifactPayload::SourceMap) && - isDerivedFrom(assocDesc.style, ArtifactStyle::Obfuscated)) - { - ComPtr<ICastable> castable; - SLANG_RETURN_ON_FAIL(associated->getOrCreateRepresentation( - SourceMap::getTypeGuid(), - ArtifactKeep::Yes, - castable.writeRef())); - auto sourceMapBox = asBoxValue<SourceMap>(castable); - SLANG_ASSERT(sourceMapBox); - - // TODO(JS): - // There is perhaps (?) a risk here that we might copy the obfuscated map - // into some output container. Currently that only happens for source maps - // that are from translation units. - // - // On the other hand using "import" is a way that such source maps *would* be - // copied into the output, and that is something that could be a vector - // for leaking. - // - // That isn't a risk from -r though because, it doesn't create a translation - // unit(s). - for (auto module : library->m_modules) - { - module->getIRModule()->setObfuscatedSourceMap(sourceMapBox); - } - - // Look up the source file - auto sourceManager = req->getSink()->getSourceManager(); - - auto name = Path::getFileNameWithoutExt(associated->getName()); - - if (name.getLength()) - { - // Note(tfoley): There is a subtle requirement here, that any - // source file `name` that might be searched for here *must* - // have been added to the `sourceManager` already, as a - // byproduct of debug source location information getting - // deserialized as part of the call to `loadModuleLibrary()` above. - // - // The implicit dependency is frustrating, and could potentially - // break if somehow the debug info chunk was stripped from a binary, - // while the source map was left in (which should be valid, even if - // it is unlikely to be what a user wants). - // - // Ideally the source map would either be made an integral part of - // the debug source location chunk, so they are loaded together, - // or the `SourceManager` would be adapted so that it can store - // registered source maps independent of whether or not the - // corresponding source file(s) have been loaded. - - auto sourceFile = sourceManager->findSourceFileByPathRecursively(name); - SLANG_ASSERT(sourceFile); - sourceFile->setSourceMap(sourceMapBox, SourceMapKind::Obfuscated); - } - } - } - - SLANG_RETURN_ON_FAIL(_addLibraryReference(req, library, includeEntryPoint)); - return SLANG_OK; - } - - if (desc.kind == ArtifactKind::Library && desc.payload == ArtifactPayload::SlangIR) - { - ComPtr<IModuleLibrary> libraryIntf; - - SLANG_RETURN_ON_FAIL( - loadModuleLibrary(ArtifactKeep::Yes, artifact, path, req, libraryIntf)); - - auto library = as<ModuleLibrary>(libraryIntf); - if (!library) - { - return SLANG_FAIL; - } - - SLANG_RETURN_ON_FAIL(_addLibraryReference(req, library, includeEntryPoint)); - return SLANG_OK; - } - - // TODO(JS): - // Do we want to check the path exists? - - // Add to the m_libModules - auto linkage = req->getLinkage(); - linkage->m_libModules.add(ComPtr<IArtifact>(artifact)); - - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::addLibraryReference( - const char* basePath, - const void* libData, - size_t libDataSize) -{ - // We need to deserialize and add the modules - ComPtr<IModuleLibrary> library; - - auto libBlob = RawBlob::create((const Byte*)libData, libDataSize); - - SLANG_RETURN_ON_FAIL( - loadModuleLibrary(libBlob, (const Byte*)libData, libDataSize, basePath, this, library)); - - // Create an artifact without any name (as one is not provided) - auto artifact = - Artifact::create(ArtifactDesc::make(ArtifactKind::Library, ArtifactPayload::SlangIR)); - artifact->addRepresentation(library); - - return _addLibraryReference(this, basePath, artifact, true); -} - -void EndToEndCompileRequest::addTranslationUnitPreprocessorDefine( - int translationUnitIndex, - const char* key, - const char* value) -{ - getFrontEndReq()->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; -} - -void EndToEndCompileRequest::addTranslationUnitSourceFile( - int translationUnitIndex, - char const* path) -{ - auto frontEndReq = getFrontEndReq(); - if (!path) - return; - if (translationUnitIndex < 0) - return; - if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) - return; - - frontEndReq->addTranslationUnitSourceFile(translationUnitIndex, path); -} - -void EndToEndCompileRequest::addTranslationUnitSourceString( - int translationUnitIndex, - char const* path, - char const* source) -{ - if (!source) - return; - addTranslationUnitSourceStringSpan(translationUnitIndex, path, source, source + strlen(source)); -} - -void EndToEndCompileRequest::addTranslationUnitSourceStringSpan( - int translationUnitIndex, - char const* path, - char const* sourceBegin, - char const* sourceEnd) -{ - auto frontEndReq = getFrontEndReq(); - if (!sourceBegin) - return; - if (translationUnitIndex < 0) - return; - if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) - return; - - if (!path) - path = ""; - - const auto slice = UnownedStringSlice(sourceBegin, sourceEnd); - - auto blob = RawBlob::create(slice.begin(), slice.getLength()); - - frontEndReq->addTranslationUnitSourceBlob(translationUnitIndex, path, blob); -} - -void EndToEndCompileRequest::addTranslationUnitSourceBlob( - int translationUnitIndex, - char const* path, - ISlangBlob* sourceBlob) -{ - auto frontEndReq = getFrontEndReq(); - if (!sourceBlob) - return; - if (translationUnitIndex < 0) - return; - if (Slang::Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) - return; - - if (!path) - path = ""; - - frontEndReq->addTranslationUnitSourceBlob(translationUnitIndex, path, sourceBlob); -} - - -int EndToEndCompileRequest::addEntryPoint( - int translationUnitIndex, - char const* name, - SlangStage stage) -{ - return addEntryPointEx(translationUnitIndex, name, stage, 0, nullptr); -} - -int EndToEndCompileRequest::addEntryPointEx( - int translationUnitIndex, - char const* name, - SlangStage stage, - int genericParamTypeNameCount, - char const** genericParamTypeNames) -{ - auto frontEndReq = getFrontEndReq(); - if (!name) - return -1; - if (translationUnitIndex < 0) - return -1; - if (Index(translationUnitIndex) >= frontEndReq->translationUnits.getCount()) - return -1; - - List<String> typeNames; - for (int i = 0; i < genericParamTypeNameCount; i++) - typeNames.add(genericParamTypeNames[i]); - - return addEntryPoint(translationUnitIndex, name, Profile(Stage(stage)), typeNames); -} - -SlangResult EndToEndCompileRequest::setGlobalGenericArgs( - int genericArgCount, - char const** genericArgs) -{ - auto& argStrings = m_globalSpecializationArgStrings; - argStrings.clear(); - for (int i = 0; i < genericArgCount; i++) - argStrings.add(genericArgs[i]); - - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::setTypeNameForGlobalExistentialTypeParam( - int slotIndex, - char const* typeName) -{ - if (slotIndex < 0) - return SLANG_FAIL; - if (!typeName) - return SLANG_FAIL; - - auto& typeArgStrings = m_globalSpecializationArgStrings; - if (Index(slotIndex) >= typeArgStrings.getCount()) - typeArgStrings.setCount(slotIndex + 1); - typeArgStrings[slotIndex] = String(typeName); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::setTypeNameForEntryPointExistentialTypeParam( - int entryPointIndex, - int slotIndex, - char const* typeName) -{ - if (entryPointIndex < 0) - return SLANG_FAIL; - if (slotIndex < 0) - return SLANG_FAIL; - if (!typeName) - return SLANG_FAIL; - - if (Index(entryPointIndex) >= m_entryPoints.getCount()) - return SLANG_FAIL; - - auto& entryPointInfo = m_entryPoints[entryPointIndex]; - auto& typeArgStrings = entryPointInfo.specializationArgStrings; - if (Index(slotIndex) >= typeArgStrings.getCount()) - typeArgStrings.setCount(slotIndex + 1); - typeArgStrings[slotIndex] = String(typeName); - return SLANG_OK; -} - -void EndToEndCompileRequest::setAllowGLSLInput(bool value) -{ - getOptionSet().set(CompilerOptionName::AllowGLSL, value); -} - -SlangResult EndToEndCompileRequest::compile() -{ - SlangResult res = SLANG_FAIL; - double downstreamStartTime = 0.0; - double totalStartTime = 0.0; - - if (getOptionSet().getBoolOption(CompilerOptionName::ReportDownstreamTime)) - { - getSession()->getCompilerElapsedTime(&totalStartTime, &downstreamStartTime); - PerformanceProfiler::getProfiler()->clear(); - } -#if !defined(SLANG_DEBUG_INTERNAL_ERROR) - // By default we'd like to catch as many internal errors as possible, - // and report them to the user nicely (rather than just crash their - // application). Internally Slang currently uses exceptions for this. - // - // TODO: Consider using `setjmp()`-style escape so that we can work - // with applications that disable exceptions. - // - // TODO: Consider supporting Windows "Structured Exception Handling" - // so that we can also recover from a wider class of crashes. - - try - { - SLANG_PROFILE_SECTION(compileInner); - res = executeActions(); - } - catch (const AbortCompilationException& e) - { - // This situation indicates a fatal (but not necessarily internal) error - // that forced compilation to terminate. There should already have been - // a diagnostic produced, so we don't need to add one here. - if (getSink()->getErrorCount() == 0) - { - // If for some reason we didn't output any diagnostic, something is - // going wrong, but we want to make sure we at least output something. - getSink()->diagnose( - SourceLoc(), - Diagnostics::compilationAbortedDueToException, - typeid(e).name(), - e.Message); - } - } - catch (const Exception& e) - { - // The compiler failed due to an internal error that was detected. - // We will print out information on the exception to help out the user - // in either filing a bug, or locating what in their code created - // a problem. - getSink()->diagnose( - SourceLoc(), - Diagnostics::compilationAbortedDueToException, - typeid(e).name(), - e.Message); - } - catch (...) - { - // The compiler failed due to some exception that wasn't a sublass of - // `Exception`, so something really fishy is going on. We want to - // let the user know that we messed up, so they know to blame Slang - // and not some other component in their system. - getSink()->diagnose(SourceLoc(), Diagnostics::compilationAborted); - } - m_diagnosticOutput = getSink()->outputBuffer.produceString(); - -#else - // When debugging, we probably don't want to filter out any errors, since - // we are probably trying to root-cause and *fix* those errors. - { - res = req->executeActions(); - } -#endif - - if (getOptionSet().getBoolOption(CompilerOptionName::ReportDownstreamTime)) - { - double downstreamEndTime = 0; - double totalEndTime = 0; - getSession()->getCompilerElapsedTime(&totalEndTime, &downstreamEndTime); - double downstreamTime = downstreamEndTime - downstreamStartTime; - String downstreamTimeStr = String(downstreamTime, "%.2f"); - getSink()->diagnose(SourceLoc(), Diagnostics::downstreamCompileTime, downstreamTimeStr); - } - if (getOptionSet().getBoolOption(CompilerOptionName::ReportPerfBenchmark)) - { - StringBuilder perfResult; - PerformanceProfiler::getProfiler()->getResult(perfResult); - perfResult << "\nType Dictionary Size: " << getSession()->m_typeDictionarySize << "\n"; - getSink()->diagnose( - SourceLoc(), - Diagnostics::performanceBenchmarkResult, - perfResult.produceString()); - } - - // Repro dump handling - { - auto dumpRepro = getOptionSet().getStringOption(CompilerOptionName::DumpRepro); - auto dumpReproOnError = getOptionSet().getBoolOption(CompilerOptionName::DumpReproOnError); - - if (dumpRepro.getLength()) - { - SlangResult saveRes = ReproUtil::saveState(this, dumpRepro); - if (SLANG_FAILED(saveRes)) - { - getSink()->diagnose(SourceLoc(), Diagnostics::unableToWriteReproFile, dumpRepro); - return saveRes; - } - } - else if (dumpReproOnError && SLANG_FAILED(res)) - { - String reproFileName; - SlangResult saveRes = SLANG_FAIL; - - RefPtr<Stream> stream; - if (SLANG_SUCCEEDED(ReproUtil::findUniqueReproDumpStream(this, reproFileName, stream))) - { - saveRes = ReproUtil::saveState(this, stream); - } - - if (SLANG_FAILED(saveRes)) - { - getSink()->diagnose( - SourceLoc(), - Diagnostics::unableToWriteReproFile, - reproFileName); - } - } - } - - auto reflectionPath = getOptionSet().getStringOption(CompilerOptionName::EmitReflectionJSON); - if (reflectionPath.getLength() != 0) - { - auto bufferWriter = PrettyWriter(); - emitReflectionJSON(this, this->getReflection(), bufferWriter); - if (reflectionPath == "-") - { - auto builder = bufferWriter.getBuilder(); - StdWriters::getOut().write(builder.getBuffer(), builder.getLength()); - } - else if (SLANG_FAILED(File::writeAllText(reflectionPath, bufferWriter.getBuilder()))) - { - getSink()->diagnose(SourceLoc(), Diagnostics::unableToWriteFile, reflectionPath); - } - } - - return res; -} - -int EndToEndCompileRequest::getDependencyFileCount() -{ - auto frontEndReq = getFrontEndReq(); - auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); - return (int)program->getFileDependencies().getCount(); -} - -char const* EndToEndCompileRequest::getDependencyFilePath(int index) -{ - auto frontEndReq = getFrontEndReq(); - auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); - SourceFile* sourceFile = program->getFileDependencies()[index]; - return sourceFile->getPathInfo().hasFoundPath() - ? sourceFile->getPathInfo().getMostUniqueIdentity().getBuffer() - : "unknown"; -} - -int EndToEndCompileRequest::getTranslationUnitCount() -{ - return (int)getFrontEndReq()->translationUnits.getCount(); -} - -void const* EndToEndCompileRequest::getEntryPointCode(int entryPointIndex, size_t* outSize) -{ - // Zero the size initially, in case need to return nullptr for error. - if (outSize) - { - *outSize = 0; - } - - auto linkage = getLinkage(); - auto program = getSpecializedGlobalAndEntryPointsComponentType(); - - // TODO: We should really accept a target index in this API - Index targetIndex = 0; - auto targetCount = linkage->targets.getCount(); - if (targetIndex >= targetCount) - return nullptr; - auto targetReq = linkage->targets[targetIndex]; - - - if (entryPointIndex < 0) - return nullptr; - if (Index(entryPointIndex) >= program->getEntryPointCount()) - return nullptr; - auto entryPoint = program->getEntryPoint(entryPointIndex); - - auto targetProgram = program->getTargetProgram(targetReq); - if (!targetProgram) - return nullptr; - IArtifact* artifact = targetProgram->getExistingEntryPointResult(entryPointIndex); - if (!artifact) - { - return nullptr; - } - - ComPtr<ISlangBlob> blob; - SLANG_RETURN_NULL_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, blob.writeRef())); - - if (outSize) - { - *outSize = blob->getBufferSize(); - } - - return (void*)blob->getBufferPointer(); -} - -SlangResult EndToEndCompileRequest::getCompileTimeProfile( - ISlangProfiler** compileTimeProfile, - bool shouldClear) -{ - if (compileTimeProfile == nullptr) - { - return SLANG_E_INVALID_ARG; - } - - SlangProfiler* profiler = new SlangProfiler(PerformanceProfiler::getProfiler()); - - if (shouldClear) - { - PerformanceProfiler::getProfiler()->clear(); - } - - ComPtr<ISlangProfiler> result(profiler); - *compileTimeProfile = result.detach(); - return SLANG_OK; -} - -static SlangResult _getEntryPointResult( - EndToEndCompileRequest* req, - int entryPointIndex, - int targetIndex, - ComPtr<IArtifact>& outArtifact) -{ - auto linkage = req->getLinkage(); - auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - - Index targetCount = linkage->targets.getCount(); - if ((targetIndex < 0) || (targetIndex >= targetCount)) - { - return SLANG_E_INVALID_ARG; - } - auto targetReq = linkage->targets[targetIndex]; - - // Get the entry point count on the program, rather than (say) req->m_entryPoints.getCount() - // because - // 1) The entry point is fetched from the program anyway so must be consistent - // 2) The req may not have all entry points (for example when an entry point is in a module) - const Index entryPointCount = program->getEntryPointCount(); - - if ((entryPointIndex < 0) || (entryPointIndex >= entryPointCount)) - { - return SLANG_E_INVALID_ARG; - } - auto entryPointReq = program->getEntryPoint(entryPointIndex); - - auto targetProgram = program->getTargetProgram(targetReq); - if (!targetProgram) - return SLANG_FAIL; - - outArtifact = targetProgram->getExistingEntryPointResult(entryPointIndex); - return SLANG_OK; -} - -static SlangResult _getWholeProgramResult( - EndToEndCompileRequest* req, - int targetIndex, - ComPtr<IArtifact>& outArtifact) -{ - auto linkage = req->getLinkage(); - auto program = req->getSpecializedGlobalAndEntryPointsComponentType(); - - if (!program) - { - return SLANG_FAIL; - } - - Index targetCount = linkage->targets.getCount(); - if ((targetIndex < 0) || (targetIndex >= targetCount)) - { - return SLANG_E_INVALID_ARG; - } - auto targetReq = linkage->targets[targetIndex]; - - auto targetProgram = program->getTargetProgram(targetReq); - if (!targetProgram) - return SLANG_FAIL; - outArtifact = targetProgram->getExistingWholeProgramResult(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getEntryPointCodeBlob( - int entryPointIndex, - int targetIndex, - ISlangBlob** outBlob) -{ - if (!outBlob) - return SLANG_E_INVALID_ARG; - ComPtr<IArtifact> artifact; - SLANG_RETURN_ON_FAIL(_getEntryPointResult(this, entryPointIndex, targetIndex, artifact)); - SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, outBlob)); - - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary) -{ - if (!outSharedLibrary) - return SLANG_E_INVALID_ARG; - ComPtr<IArtifact> artifact; - SLANG_RETURN_ON_FAIL(_getEntryPointResult(this, entryPointIndex, targetIndex, artifact)); - SLANG_RETURN_ON_FAIL(artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary)); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) -{ - if (!outBlob) - return SLANG_E_INVALID_ARG; - - ComPtr<IArtifact> artifact; - SLANG_RETURN_ON_FAIL(_getWholeProgramResult(this, targetIndex, artifact)); - SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::Yes, outBlob)); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getTargetHostCallable( - int targetIndex, - ISlangSharedLibrary** outSharedLibrary) -{ - if (!outSharedLibrary) - return SLANG_E_INVALID_ARG; - - ComPtr<IArtifact> artifact; - SLANG_RETURN_ON_FAIL(_getWholeProgramResult(this, targetIndex, artifact)); - SLANG_RETURN_ON_FAIL(artifact->loadSharedLibrary(ArtifactKeep::Yes, outSharedLibrary)); - return SLANG_OK; -} - -char const* EndToEndCompileRequest::getEntryPointSource(int entryPointIndex) -{ - return (char const*)getEntryPointCode(entryPointIndex, nullptr); -} - -ISlangMutableFileSystem* EndToEndCompileRequest::getCompileRequestResultAsFileSystem() -{ - if (!m_containerFileSystem) - { - if (m_containerArtifact) - { - ComPtr<ISlangMutableFileSystem> fileSystem(new MemoryFileSystem); - - // Filter the containerArtifact into things that can be written - ComPtr<IArtifact> writeArtifact; - if (SLANG_SUCCEEDED( - ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)) && - writeArtifact) - { - if (SLANG_SUCCEEDED( - ArtifactContainerUtil::writeContainer(writeArtifact, "", fileSystem))) - { - m_containerFileSystem.swap(fileSystem); - } - } - } - } - - return m_containerFileSystem; -} - -void const* EndToEndCompileRequest::getCompileRequestCode(size_t* outSize) -{ - if (m_containerArtifact) - { - ComPtr<ISlangBlob> containerBlob; - if (SLANG_SUCCEEDED( - m_containerArtifact->loadBlob(ArtifactKeep::Yes, containerBlob.writeRef()))) - { - *outSize = containerBlob->getBufferSize(); - return containerBlob->getBufferPointer(); - } - } - - // Container blob does not have any contents - *outSize = 0; - return nullptr; -} - -SlangResult EndToEndCompileRequest::getContainerCode(ISlangBlob** outBlob) -{ - if (m_containerArtifact) - { - ComPtr<ISlangBlob> containerBlob; - if (SLANG_SUCCEEDED( - m_containerArtifact->loadBlob(ArtifactKeep::Yes, containerBlob.writeRef()))) - { - *outBlob = containerBlob.detach(); - return SLANG_OK; - } - } - return SLANG_FAIL; -} - -SlangResult EndToEndCompileRequest::loadRepro( - ISlangFileSystem* fileSystem, - const void* data, - size_t size) -{ - List<uint8_t> buffer; - SLANG_RETURN_ON_FAIL(ReproUtil::loadState((const uint8_t*)data, size, getSink(), buffer)); - - MemoryOffsetBase base; - base.set(buffer.getBuffer(), buffer.getCount()); - - ReproUtil::RequestState* requestState = ReproUtil::getRequest(buffer); - - SLANG_RETURN_ON_FAIL(ReproUtil::load(base, requestState, fileSystem, this)); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::saveRepro(ISlangBlob** outBlob) -{ - OwnedMemoryStream stream(FileAccess::Write); - - SLANG_RETURN_ON_FAIL(ReproUtil::saveState(this, &stream)); - - // Put the content of the stream in the blob - - List<uint8_t> data; - stream.swapContents(data); - - *outBlob = ListBlob::moveCreate(data).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::enableReproCapture() -{ - getLinkage()->setRequireCacheFileSystem(true); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::processCommandLineArguments( - char const* const* args, - int argCount) -{ - return parseOptions(this, argCount, args); -} - -SlangReflection* EndToEndCompileRequest::getReflection() -{ - auto linkage = getLinkage(); - auto program = getSpecializedGlobalAndEntryPointsComponentType(); - - // Note(tfoley): The API signature doesn't let the client - // specify which target they want to access reflection - // information for, so for now we default to the first one. - // - // TODO: Add a new `spGetReflectionForTarget(req, targetIndex)` - // so that we can do this better, and make it clear that - // `spGetReflection()` is shorthand for `targetIndex == 0`. - // - Slang::Index targetIndex = 0; - auto targetCount = linkage->targets.getCount(); - if (targetIndex >= targetCount) - return nullptr; - - auto targetReq = linkage->targets[targetIndex]; - auto targetProgram = program->getTargetProgram(targetReq); - - - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - auto programLayout = targetProgram->getOrCreateLayout(&sink); - - return (SlangReflection*)programLayout; -} - -SlangResult EndToEndCompileRequest::getProgram(slang::IComponentType** outProgram) -{ - auto program = getSpecializedGlobalComponentType(); - *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getProgramWithEntryPoints(slang::IComponentType** outProgram) -{ - auto program = getSpecializedGlobalAndEntryPointsComponentType(); - *outProgram = Slang::ComPtr<slang::IComponentType>(program).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getModule( - SlangInt translationUnitIndex, - slang::IModule** outModule) -{ - auto module = getFrontEndReq()->getTranslationUnit(translationUnitIndex)->getModule(); - - *outModule = Slang::ComPtr<slang::IModule>(module).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getSession(slang::ISession** outSession) -{ - auto session = getLinkage(); - *outSession = Slang::ComPtr<slang::ISession>(session).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::getEntryPoint( - SlangInt entryPointIndex, - slang::IComponentType** outEntryPoint) -{ - auto entryPoint = getSpecializedEntryPointComponentType(entryPointIndex); - *outEntryPoint = Slang::ComPtr<slang::IComponentType>(entryPoint).detach(); - return SLANG_OK; -} - -SlangResult EndToEndCompileRequest::isParameterLocationUsed( - Int entryPointIndex, - Int targetIndex, - SlangParameterCategory category, - UInt spaceIndex, - UInt registerIndex, - bool& outUsed) -{ - if (!ShaderBindingRange::isUsageTracked((slang::ParameterCategory)category)) - return SLANG_E_NOT_AVAILABLE; - - ComPtr<IArtifact> artifact; - if (SLANG_FAILED(_getEntryPointResult( - this, - static_cast<int>(entryPointIndex), - static_cast<int>(targetIndex), - artifact))) - return SLANG_E_INVALID_ARG; - - if (!artifact) - return SLANG_E_NOT_AVAILABLE; - - // Find a rep - auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); - if (!metadata) - return SLANG_E_NOT_AVAILABLE; - - return metadata->isParameterLocationUsed(category, spaceIndex, registerIndex, outUsed); -} - } // namespace Slang |
