diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 108 |
1 files changed, 98 insertions, 10 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index fce90d612..bfc77b2e3 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -38,13 +38,15 @@ namespace Slang { // Allocate static const storage for the various interface IDs that the Slang API needs to expose -static const Guid IID_ISlangUnknown = SLANG_UUID_ISlangUnknown; -static const Guid IID_ISlangBlob = SLANG_UUID_ISlangBlob; -static const Guid IID_ISession = SLANG_UUID_ISession; -static const Guid IID_IGlobalSession = SLANG_UUID_IGlobalSession; -static const Guid IID_IModule = SLANG_UUID_IModule; +static const Guid IID_IComponentType = SLANG_UUID_IComponentType; +static const Guid IID_IEntryPoint = SLANG_UUID_IEntryPoint; +static const Guid IID_IGlobalSession = SLANG_UUID_IGlobalSession; +static const Guid IID_IModule = SLANG_UUID_IModule; +static const Guid IID_ISession = SLANG_UUID_ISession; +static const Guid IID_ISlangBlob = SLANG_UUID_ISlangBlob; +static const Guid IID_ISlangUnknown = SLANG_UUID_ISlangUnknown; -Session::Session() +void Session::init() { ::memset(m_downstreamCompilerLocators, 0, sizeof(m_downstreamCompilerLocators)); DownstreamCompilerUtil::setDefaultLocators(m_downstreamCompilerLocators); @@ -77,6 +79,16 @@ Session::Session() m_builtinLinkage = new Linkage(this); + // Because the `Session` retains the builtin `Linkage`, + // we need to make sure that the parent pointer inside + // `Linkage` doesn't create a retain cycle. + // + // This operation ensures that the parent pointer will + // just be a raw pointer, so that the builtin linkage + // doesn't keep the parent session alive. + // + m_builtinLinkage->_stopRetainingParentSession(); + // Initialize representations of some very basic types: initializeTypes(); @@ -437,6 +449,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) Linkage::Linkage(Session* session) : m_session(session) + , m_retainedSession(session) , m_sourceManager(&m_defaultSourceManager) { getNamePool()->setRootNamePool(session->getRootNamePool()); @@ -1704,9 +1717,9 @@ Module::Module(Linkage* linkage) ISlangUnknown* Module::getInterface(const Guid& guid) { - if(guid == IID_ISlangUnknown || guid == IID_IModule) + if(guid == IID_IModule) return asExternal(this); - return nullptr; + return Super::getInterface(guid); } void Module::addModuleDependency(Module* module) @@ -1725,14 +1738,51 @@ void Module::setModuleDecl(ModuleDecl* moduleDecl) m_moduleDecl = moduleDecl; } -// ComponentType +RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) +{ + // TODO: We should consider having this function be expanded to be able + // to look up and validate possible entry-point functions in teh module + // even if they were not marked with `[shader(...)]` in the source code. + // + // With such a change the function would probably need to accept a stage + // to use and a sink to write validation errors to. + + for(auto entryPoint : m_entryPoints) + { + if(entryPoint->getName()->text.getUnownedSlice() == name) + return entryPoint; + } + + return nullptr; +} + +void Module::_addEntryPoint(EntryPoint* entryPoint) +{ + m_entryPoints.add(entryPoint); +} -static const Guid IID_IComponentType = SLANG_UUID_IComponentType; + +// ComponentType ComponentType::ComponentType(Linkage* linkage) : m_linkage(linkage) {} +ComponentType* asInternal(slang::IComponentType* inComponentType) +{ + // Note: we use a `queryInterface` here instead of just a `static_cast` + // to ensure that the `IComponentType` we get is the preferred/canonical + // one, which shares its address with the `ComponentType`. + // + // TODO: An alternative choice here would be to have a "magic" IID that + // we pass into `queryInterface` that returns the `ComponentType` directly + // (without even `addRef`-ing it). + // + ComPtr<slang::IComponentType> componentType; + inComponentType->queryInterface(IID_IComponentType, (void**) componentType.writeRef()); + return static_cast<ComponentType*>(componentType.get()); +} + ISlangUnknown* ComponentType::getInterface(Guid const& guid) { if(guid == IID_ISlangUnknown @@ -1873,6 +1923,28 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( return SLANG_OK; } +RefPtr<ComponentType> fillRequirements( + ComponentType* inComponentType); + +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) +{ + // TODO: It should be possible for `fillRequirements` to fail, + // in cases where we have a dependency that can't be automatically + // resolved. + // + SLANG_UNUSED(outDiagnostics); + + auto linked = fillRequirements(this); + if(!linked) + return SLANG_FAIL; + + *outLinkedComponentType = ComPtr<slang::IComponentType>(linked).detach(); + return SLANG_OK; +} + + /// Visitor used by `ComponentType::enumerateModules` struct EnumerateModulesVisitor : ComponentTypeVisitor { @@ -2488,6 +2560,7 @@ Session::~Session() SLANG_API SlangSession* spCreateSession(const char*) { Slang::RefPtr<Slang::Session> session(new Slang::Session()); + session->init(); // Will be returned with a refcount of 1 return asExternal(session.detach()); } @@ -2500,6 +2573,7 @@ SLANG_API SlangResult slang_createGlobalSession( return SLANG_E_NOT_IMPLEMENTED; Slang::Session* globalSession = new Slang::Session(); + globalSession->init(); Slang::ComPtr<slang::IGlobalSession> result(Slang::asExternal(globalSession)); *outGlobalSession = result.detach(); return SLANG_OK; @@ -3487,6 +3561,20 @@ SLANG_API SlangResult spCompileRequest_getProgram( return SLANG_OK; } +SLANG_API SlangResult spCompileRequest_getModule( + SlangCompileRequest* request, + SlangInt translationUnitIndex, + slang::IModule** outModule) +{ + if( !request ) return SLANG_ERROR_INVALID_PARAMETER; + auto req = Slang::asInternal(request); + + auto module = req->getFrontEndReq()->getTranslationUnit(translationUnitIndex)->getModule(); + + *outModule = Slang::ComPtr<slang::IModule>(module).detach(); + return SLANG_OK; +} + SLANG_API SlangResult spCompileRequest_getEntryPoint( SlangCompileRequest* request, SlangInt entryPointIndex, |
