diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/ast-legalize.cpp | 100 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 17 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 217 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 32 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 319 | ||||
| -rw-r--r-- | source/slang/ir.h | 1 | ||||
| -rw-r--r-- | source/slang/options.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 13 |
8 files changed, 514 insertions, 191 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 8e4da2717..afc8b31c8 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -2,6 +2,7 @@ #include "ast-legalize.h" #include "emit.h" +#include "ir-insts.h" #include "type-layout.h" #include "visitor.h" @@ -434,6 +435,10 @@ struct SharedLoweringContext CompileRequest* compileRequest; EntryPointRequest* entryPointRequest; + // The "main" module that is being translated (as opposed + // to any of the modules that might have been imported). + ModuleDecl* mainModuleDecl; + ExtensionUsageTracker* extensionUsageTracker; ProgramLayout* programLayout; @@ -463,6 +468,12 @@ struct SharedLoweringContext bool isRewrite = false; bool requiresCopyGLPositionToPositionPerView = false; + + // State for lowering imported declarations to IR as needed + IRSpecializationState* irSpecializationState = nullptr; + + // The actual result we want to return + LoweredEntryPoint result; }; static void attachLayout( @@ -2123,7 +2134,7 @@ struct LoweringVisitor RefPtr<ScopeStmt> loweredStmt, RefPtr<ScopeStmt> originalStmt) { - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); LoweringVisitor subVisitor = *this; subVisitor.isBuildingStmt = true; @@ -2286,7 +2297,7 @@ struct LoweringVisitor ScopeStmt* originalStmt) { lowerStmtFields(loweredStmt, originalStmt); - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); } // Child statements reference their parent statement, @@ -2586,7 +2597,7 @@ struct LoweringVisitor if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions)) { RefPtr<GenericSubstitution> result = new GenericSubstitution(); - result->genericDecl = translateDeclRef(genSubst->genericDecl).As<GenericDecl>(); + result->genericDecl = translateDeclRef(genSubst->genericDecl).getDecl()->As<GenericDecl>(); for (auto arg : genSubst->args) { result->args.Add(translateVal(arg)); @@ -2612,17 +2623,36 @@ struct LoweringVisitor } LoweredDeclRef translateDeclRef( - DeclRef<Decl> const& decl) + DeclRef<Decl> const& declRef) { LoweredDeclRef result; - result.decl = translateDeclRef(decl.decl); - result.substitutions = translateSubstitutions(decl.substitutions); + result.decl = translateDeclRefImpl(declRef); + result.substitutions = translateSubstitutions(declRef.substitutions); return result; } LoweredDecl translateDeclRef( - Decl* decl) + Decl* decl) { + return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); + } + + // Try to find the module that (recursively) contains a given declaration. + ModuleDecl* findModuleForDecl( + Decl* decl) + { + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) + return moduleDecl; + } + return nullptr; + } + + LoweredDecl translateDeclRefImpl( + DeclRef<Decl> declRef) + { + Decl* decl = declRef.getDecl(); if (!decl) return LoweredDecl(); // We don't want to translate references to built-in declarations, @@ -2641,6 +2671,38 @@ struct LoweringVisitor if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>()) return decl; + // If we are using the IR, and the declaration comes from + // an imported module (rather than the "rewrite-mode" module + // being translated), then we need to ensure that it gets lowered + // to IR instead. + if (shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) + { + auto parentModule = findModuleForDecl(decl); + if (parentModule && (parentModule != shared->mainModuleDecl)) + { + // Ensure that the IR code for the given declaration + // gets included in the output IR module, and *also* + // that we generate a suitable specialization of it + // if there are any substitutions in effect. + + getSpecializedGlobalValueForDeclRef( + shared->irSpecializationState, + declRef); + + // Remember that this declaration is handled via IR, + // rather than being present in the legalized AST. + shared->result.irDecls.Add(declRef.getDecl()); + + // We don't actually use the `IRGlobalValue` that the + // above operation returns, and instead just keep + // using the original declaration in the legalized + // AST. The step of mapping that declaration over + // to reference the IR symbol will happen later. + + return decl; + } + } + LoweredDecl loweredDecl; if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) return loweredDecl; @@ -2649,10 +2711,10 @@ struct LoweringVisitor return lowerDecl(decl); } - RefPtr<ContainerDecl> translateDeclRef( - ContainerDecl* decl) + DeclRef<ContainerDecl> translateDeclRef( + DeclRef<ContainerDecl> declRef) { - return translateDeclRef((Decl*)decl).getDecl()->As<ContainerDecl>(); + return translateDeclRef(declRef).As<ContainerDecl>(); } LoweredDecl lowerDeclBase( @@ -2759,9 +2821,9 @@ struct LoweringVisitor { RefPtr<Decl> loweredParent; if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>()) - loweredParent = translateDeclRef(genericParentDecl->ParentDecl); + loweredParent = translateDeclRef(genericParentDecl->ParentDecl).getDecl(); else - loweredParent = translateDeclRef(decl->ParentDecl); + loweredParent = translateDeclRef(decl->ParentDecl).getDecl(); if (loweredParent) { auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); @@ -3518,7 +3580,7 @@ struct LoweringVisitor if (auto parentModuleDecl = pp.As<ModuleDecl>()) { LoweringVisitor subVisitor = *this; - subVisitor.parentDecl = translateDeclRef(parentModuleDecl); + subVisitor.parentDecl = translateDeclRef(parentModuleDecl).getDecl()->As<ContainerDecl>(); subVisitor.isBuildingStmt = false; return subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); @@ -4659,7 +4721,8 @@ LoweredEntryPoint lowerEntryPoint( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker) + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4667,8 +4730,10 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.programLayout = programLayout; sharedContext.target = target; sharedContext.extensionUsageTracker = extensionUsageTracker; + sharedContext.irSpecializationState = irSpecializationState; auto translationUnit = entryPoint->getTranslationUnit(); + sharedContext.mainModuleDecl = translationUnit->SyntaxNode; // Create a single module/program to hold all the lowered code // (with the exception of instrinsic/stdlib declarations, which @@ -4711,7 +4776,6 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.entryPointLayout = visitor.findEntryPointLayout(entryPoint); - LoweredEntryPoint result; if (isRewrite) { for (auto dd : translationUnit->SyntaxNode->Members) @@ -4722,11 +4786,11 @@ LoweredEntryPoint lowerEntryPoint( else { auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); - result.entryPoint = loweredEntryPoint; + sharedContext.result.entryPoint = loweredEntryPoint; } - result.program = sharedContext.loweredProgram; + sharedContext.result.program = sharedContext.loweredProgram; - return result; + return sharedContext.result; } } diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 071ff6c51..9046e8df0 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -37,11 +37,12 @@ namespace Slang { - class EntryPointRequest; - class ProgramLayout; - class TranslationUnitRequest; - + class EntryPointRequest; struct ExtensionUsageTracker; + struct IRSpecializationState; + class ProgramLayout; + class TranslationUnitRequest; + struct LoweredEntryPoint { @@ -52,6 +53,11 @@ namespace Slang // contains the entry point and any // other declarations it uses RefPtr<ModuleDecl> program; + + // A set of declarations that are not present + // in the generated AST, and are instead stored + // in the companion IR module + HashSet<Decl*> irDecls; }; // Emit code for a single entry point, based on @@ -60,6 +66,7 @@ namespace Slang EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker); + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index d6f1f8e1a..4a084c714 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -102,6 +102,9 @@ struct SharedEmitContext Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; HashSet<String> irTupleTypes; + + // Map used to tell AST lowering what decls are represented by IR. + HashSet<Decl*>* irDeclSetForAST = nullptr; }; struct EmitContext @@ -2942,6 +2945,19 @@ struct EmitVisitor void EmitDeclRef(DeclRef<Decl> declRef) { + // Are we emitting an AST in a context where some declarations + // are actually stored as IR code? + if(auto irDeclSet = context->shared->irDeclSetForAST) + { + Decl* decl = declRef.getDecl(); + if(irDeclSet->Contains(decl)) + { + emit(getMangledName(declRef)); + return; + } + } + + // TODO: need to qualify a declaration name based on parent scopes/declarations // Emit the name for the declaration itself @@ -6870,55 +6886,142 @@ String emitEntryPoint( // Depending on how the compiler was invoked, we may need to perform // some amount of preocessing on the code before we can emit it. // - // For our purposes, there are basically three different "modes" we - // care about: + // We try to partition the cases we need to handle into a few broad + // categories, each of which is reflected as a different code path + // below: + // + // 1. "Full rewriter" mode, where the user provides HLSL/GLSL, opts + // out of semantic checking, and doesn't make use of any Slang + // code via `import`. + // + // 2. "Partial rewriter" modes, where the user starts with HLSL/GLSL + // and opts out of checking for that code, but also imports some + // Slang code which may need cross-compilation. They may also + // need us to rewrite the AST for some of their HLSL/GLSL function + // bodies to make things work. This actually has two main sub-modes: // - // 1. "Full rewriter" mode, where the user provides HLSL/GLSL, and - // doesn't make use of any Slang code via `import`. + // a) "Without IR." If the user doesn't opt into using the IR, then + // the imported Slang code gets translated to the target languge + // via the same AST-to-AST pass that legalized the user's code. This + // mode will eventually go away, but it is the main one used right now. // - // 2. "Partial rewriter" mode, where the user starts with HLSL/GLSL, - // but also imports some Slang code, and may need us to rewrite - // their HLSL/GLSL function bodies to make things work. + // b) "With IR." If the user opts into using the IR, then we need to + // apply the AST-to-AST pass to their HLSL/GLSL code, but *also* use + // the IR to compile everything else. // - // 3. "Full" mode, where all of the input code is in Slang (and/or - // the subset of HLSL we can fully type-check). + // 3. "Full IR" mode, where we can assume all the input code is in Slang + // (or the subset of HLSL we understand) that has undergone full + // semantic checking, and the user has opted into using the IR. // - // We'll try to detect the cases here: + // We'll try to detect the cases here, starting with case (1): // - if((translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) - && !(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING )) + if ((translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + && translationUnit->compileRequest->loadedModulesList.Count() == 0) { - // This seems to be case (3), because the user is asking for full - // checking, and so we can assume we understand the code fully. + // The user has opted out of semantic checking for their own code + // (in the "main" module), and also hasn't `import`ed any Slang + // modules that would require cross-compilation. // - // The IR code for the module should already have been generated, - // so that we "just" need to specialize it as needed for the - // specific target and entry point in use. + // Our goal in this mode is to print out the AST we parsed and + // hopefully reproduce something as close to the original as possible. // - // The first pass is to extract the IR code of the entry point, - // and any other symbols it references. At the same time, - // we go ahead and select the target-specific version of - // any such functions if they are available. We also go - // ahead and apply the layout information (from `programLayout`) - // to the IR code (which previously had no layout). + // The only deviation we *want* from the original code is that we will + // add new parameter binding annotations. + + sharedContext.program = translationUnitSyntax; + visitor.EmitDeclsInContainerUsingLayout( + translationUnitSyntax, + globalStructLayout); + } + // + // Next we will check for case (2a): + else if (!(translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)) + { + // This case means the user has opted out of using the IR (so we can't use the + // cases below), but they either turned on semantic checking *or* imported some + // Slang code, so they can't use the case above. // - // Note: it is important that we extract a *copy* of all the - // relevant IR, so that transformations we make for one - // entry point (or target) don't mess up the IR used for other - // entry points (targets). + // Note: This case should go away completely once the IR is able to be relied + // upon for all cross-compilation scenarios. + + // We will apply our AST-to-AST legalization pass before we emit + // any code, and we will emit code for the AST that comes out + // of this pass instead of the original. + + // We perform legalization of the program before emitting *anything*, + // because the lowering process might change how we emit some + // boilerplate at the start of the ouput for GLSL (e.g., what + // version we require). + + auto lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + nullptr); + sharedContext.program = lowered.program; + + // Note that we emit the main body code of the program *before* + // we emit any leading preprocessor directives for GLSL. + // This is to give the emit logic a change to make last-minute + // adjustments like changing the required GLSL version. // - auto lowered = specializeIRForEntryPoint( + // TODO: All such adjustments would be better handled during + // lowering, but that requires having a semantic rather than + // textual format for the HLSL->GLSL mapping. + visitor.EmitDeclsInContainer(lowered.program.Ptr()); + } + // + // The remaining cases all require the use of our IR, and so there + // are certain steps that need to be shared. + else + { + // We are going to create a fresh IR module that we will use to + // clone any code needed by the user's entry point. + IRSpecializationState* irSpecializationState = createIRSpecializationState( entryPoint, programLayout, - target, + target, targetRequest); + IRModule* irModule = getIRModule(irSpecializationState); + + LoweredEntryPoint lowered; + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + // We are in case (2b), where the main module is in unchecked + // HLSL/GLSL that we need to "rewrite," and any library code + // is in Slang that will need to be cross-compiled via the IR. + + // Initially, we will apply the AST-to-AST pass to legalize + // the user's code, much like we would for any other target. + // Along the way, this pass will discover any IR declarations + // that we use, and try to emit code for them into our IR module. + + lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + irSpecializationState); + } + else + { + // We are in case (3), where all of the code is in Slang, and + // has already been lowered to IR as part of the front-end + // compilation work. We thus start by cloning any code needed + // by the entry point over to our fresh IR module. + + specializeIRForEntryPoint( + irSpecializationState, + entryPoint); + } // If the user specified the flag that they want us to dump // IR, then do it here, for the target-specific, but // un-specialized IR. if (translationUnit->compileRequest->shouldDumpIR) { - dumpIR(lowered); + dumpIR(irModule); } // Next, we need to ensure that the code we emit for @@ -6927,7 +7030,7 @@ String emitEntryPoint( // none of our target supports generics, or interfaces, // so we need to specialize those away. // - specializeGenerics(lowered); + specializeGenerics(irModule); // Debugging code for IR transformations... #if 0 @@ -6941,7 +7044,7 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes(lowered); + legalizeTypes(irModule); // Debugging output of legalization #if 0 @@ -6950,50 +7053,24 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + // After all of the required optimization and legalization + // passes have been performed, we can emit target code from + // the IR module. + // // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? + visitor.emitIRModule(&context, irModule); - visitor.emitIRModule(&context, lowered); + // If we are in case (2b) and the user *also* has AST-based code + // that we need to output, we'll do it now. + if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + sharedContext.irDeclSetForAST = &lowered.irDecls; + visitor.EmitDeclsInContainer(lowered.program); + } // TODO: need to clean up the IR module here } - else if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING ) || - translationUnit->compileRequest->loadedModulesList.Count() != 0) - { - // The user has `import`ed some Slang modules, and so we are in case (2) - // - // We need to apply a "rewriting" pass to the code the user wrote, - // and then emit the result. - - // We perform lowering of the program before emitting *anything*, - // because the lowering process might change how we emit some - // boilerplate at the start of the ouput for GLSL (e.g., what - // version we require). - auto lowered = lowerEntryPoint(entryPoint, programLayout, target, &sharedContext.extensionUsageTracker); - sharedContext.program = lowered.program; - - // Note that we emit the main body code of the program *before* - // we emit any leading preprocessor directives for GLSL. - // This is to give the emit logic a change to make last-minute - // adjustments like changing the required GLSL version. - // - // TODO: All such adjustments would be better handled during - // lowering, but that requires having a semantic rather than - // textual format for the HLSL->GLSL mapping. - visitor.EmitDeclsInContainer(lowered.program.Ptr()); - } - else - { - // We are in case (1). - // - // We should be able to just emit the AST we parsed right back out, - // along with whatever annotations we added along the way. - - sharedContext.program = translationUnitSyntax; - visitor.EmitDeclsInContainerUsingLayout( - translationUnitSyntax, - globalStructLayout); - } String code = sharedContext.sb.ProduceString(); sharedContext.sb.Clear(); diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index b95aea2fe..b5e1d5a1d 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -559,19 +559,45 @@ struct IRBuilder IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout); }; -// Generate a clone of an IR module that is specialized for -// a particular entry point, target, etc. -IRModule* specializeIRForEntryPoint( +// + +// Interface to IR specialization for use when cloning target-specific +// IR as part of compiling an entry point. +// +// TODO: we really need to move all of this logic to its own files. + +// `IRSpecializationState` is used as an opaque type to wrap up all +// the data needed to perform IR specialization, without exposing +// implementation details. +struct IRSpecializationState; +IRSpecializationState* createIRSpecializationState( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq); +void destroyIRSpecializationState(IRSpecializationState* state); +IRModule* getIRModule(IRSpecializationState* state); + +IRGlobalValue* getSpecializedGlobalValueForDeclRef( + IRSpecializationState* state, + DeclRef<Decl> const& declRef); + +// Clone the IR values reachable from the given entry point +// into the IR module assocaited with the specialization state. +// When multiple definitions of a symbol are found, the one +// that is best specialized for the given `targetReq` will be +// used. +void specializeIRForEntryPoint( + IRSpecializationState* state, + EntryPointRequest* entryPointRequest); // Find suitable uses of the `specialize` instruction that // can be replaced with references to specialized functions. void specializeGenerics( IRModule* module); +// + } #endif diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index ae7b71172..3a8aabd85 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3041,9 +3041,40 @@ namespace Slang IRValue* clonedValue, IRValue* originalValue) { + if(!originalValue) + return; context->getClonedValues().Add(originalValue, clonedValue); } + // Information on values to use when registering a cloned value + struct IROriginalValuesForClone + { + IRValue* originalVal = nullptr; + IRSpecSymbol* sym = nullptr; + + IROriginalValuesForClone() {} + + IROriginalValuesForClone(IRValue* originalValue) + : originalVal(originalValue) + {} + + IROriginalValuesForClone(IRSpecSymbol* symbol) + : sym(symbol) + {} + }; + + void registerClonedValue( + IRSpecContextBase* context, + IRValue* clonedValue, + IROriginalValuesForClone const& originalValues) + { + registerClonedValue(context, clonedValue, originalValues.originalVal); + for( auto s = originalValues.sym; s; s = s->nextWithSameName ) + { + registerClonedValue(context, clonedValue, s->irGlobalValue); + } + } + void cloneDecorations( IRSpecContextBase* context, IRValue* clonedValue, @@ -3100,9 +3131,7 @@ namespace Slang }; - IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar); - IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); - IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar); + IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal); RefPtr<Substitutions> cloneSubstitutions( IRSpecContext* context, Substitutions* subst); @@ -3117,16 +3146,9 @@ namespace Slang switch (originalValue->op) { case kIROp_global_var: - return cloneGlobalVar(this, (IRGlobalVar*)originalValue); - break; - case kIROp_Func: - return cloneFunc(this, (IRFunc*)originalValue); - break; - case kIROp_witness_table: - return cloneWitnessTable(this, (IRWitnessTable*)originalValue); - break; + return cloneGlobalValue(this, (IRGlobalValue*) originalValue); case kIROp_boolConst: { @@ -3334,10 +3356,13 @@ namespace Slang IRGlobalValueWithCode* clonedValue, IRGlobalValueWithCode* originalValue); - IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar) + IRGlobalVar* cloneGlobalVarImpl( + IRSpecContext* context, + IRGlobalVar* originalVar, + IROriginalValuesForClone const& originalValues) { auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType())); - registerClonedValue(context, clonedVar, originalVar); + registerClonedValue(context, clonedVar, originalValues); auto mangledName = originalVar->mangledName; clonedVar->mangledName = mangledName; @@ -3360,10 +3385,13 @@ namespace Slang return clonedVar; } - IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalTable) + IRWitnessTable* cloneWitnessTableImpl( + IRSpecContext* context, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues) { auto clonedTable = context->builder->createWitnessTable(); - registerClonedValue(context, clonedTable, originalTable); + registerClonedValue(context, clonedTable, originalValues); auto mangledName = originalTable->mangledName; clonedTable->mangledName = mangledName; @@ -3384,6 +3412,13 @@ namespace Slang return clonedTable; } + IRWitnessTable* cloneWitnessTableWithoutRegistering( + IRSpecContext* context, + IRWitnessTable* originalTable) + { + return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone()); + } + void cloneGlobalValueWithCodeCommon( IRSpecContextBase* context, IRGlobalValueWithCode* clonedValue, @@ -3556,15 +3591,6 @@ namespace Slang return clonedFunc; } - - IRFunc* cloneSimpleFunc(IRSpecContextBase* context, IRFunc* originalFunc) - { - auto clonedFunc = context->builder->createFunc(); - registerClonedValue(context, clonedFunc, originalFunc); - cloneFunctionCommon(context, clonedFunc, originalFunc); - return clonedFunc; - } - // Get a string form of the target so that we can // use it to match against target-specialization modifiers // @@ -3687,57 +3713,111 @@ namespace Slang return false; } - IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc) + IRFunc* cloneFuncImpl( + IRSpecContext* context, + IRFunc* originalFunc, + IROriginalValuesForClone const& originalValues) { - // We are being asked to clone a particular function, but in - // the IR that comes out of the front-end there could still - // be multiple, target-specific, declarations of any given - // function, all of which share the same mangled name. - auto mangledName = originalFunc->mangledName; + auto clonedFunc = context->builder->createFunc(); + registerClonedValue(context, clonedFunc, originalValues); + cloneFunctionCommon(context, clonedFunc, originalFunc); + return clonedFunc; + } + + // Directly clone a global value, based on a single definition/declaration, `originalVal`. + // The symbol `sym` will thread together other declarations of the same value, and + // we will register the new value as the cloned version of all of those. + IRGlobalValue* cloneGlobalValueImpl( + IRSpecContext* context, + IRGlobalValue* originalVal, + IRSpecSymbol* sym) + { + if( !originalVal ) + { + SLANG_UNEXPECTED("cloning a null value"); + UNREACHABLE_RETURN(nullptr); + } + + switch( originalVal->op ) + { + case kIROp_Func: + return cloneFuncImpl(context, (IRFunc*) originalVal, sym); + + case kIROp_global_var: + return cloneGlobalVarImpl(context, (IRGlobalVar*)originalVal, sym); + case kIROp_witness_table: + return cloneWitnessTableImpl(context, (IRWitnessTable*)originalVal, sym); + + default: + SLANG_UNEXPECTED("unknown global value kind"); + UNREACHABLE_RETURN(nullptr); + } + + } + + // Clone a global value, which has the given `mangledName`. + // The `originalVal` is a known global IR value with that name, if one is available. + // (It is okay for this parameter to be null). + IRGlobalValue* cloneGlobalValueWithMangledName( + IRSpecContext* context, + String const& mangledName, + IRGlobalValue* originalVal) + { if(mangledName.Length() == 0) { - return cloneSimpleFunc(context, originalFunc); + // If there is no mangled name, then we assume this is a local symbol, + // and it can't possibly have multiple declarations. + return cloneGlobalValueImpl(context, originalVal, nullptr); } // - // We will scan through all of the available function declarations - // with the same mangled name as `originalFunc` and try + // We will scan through all of the available declarations + // with the same mangled name as `originalVal` and try // to pick the "best" one for our target. RefPtr<IRSpecSymbol> sym; - if( !context->getSymbols().TryGetValue(originalFunc->mangledName, sym) ) + if( !context->getSymbols().TryGetValue(mangledName, sym) ) { // This shouldn't happen! - SLANG_UNEXPECTED("no matching function registered"); - UNREACHABLE_RETURN(cloneSimpleFunc(context, originalFunc)); + SLANG_UNEXPECTED("no matching values registered"); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); } - // We will try to track the "best" definition we can find. - IRFunc* bestFunc = (IRFunc*) sym->irGlobalValue; - + // We will try to track the "best" declaration we can find. + // + // Generally, one declaration wil lbe better than another if it is + // more specialized for the chosen target. Otherwise, we simply favor + // definitions over declarations. + // + IRGlobalValue* bestVal = sym->irGlobalValue; for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName ) { - IRFunc* newFunc = (IRFunc*) ss->irGlobalValue; - if(isBetterForTarget(context, newFunc, bestFunc)) - bestFunc = newFunc; + IRGlobalValue* newVal = ss->irGlobalValue; + if(isBetterForTarget(context, newVal, bestVal)) + bestVal = newVal; } - // All right, we are now in a position to clone the "best" - // definition that was found. - auto clonedFunc = context->builder->createFunc(); - - // The resulting function will be used as the cloned version - // of every declaration/definition in the original IR. - for( auto ss = sym; ss; ss = ss->nextWithSameName ) - { - registerClonedValue(context, clonedFunc, ss->irGlobalValue); - } + return cloneGlobalValueImpl(context, bestVal, sym); + } - // Clone the "best" definition into our context - cloneFunctionCommon(context, clonedFunc, bestFunc); + IRGlobalValue* cloneGlobalValueWithMangledName(IRSpecContext* context, String const& mangledName) + { + return cloneGlobalValueWithMangledName(context, mangledName, nullptr); + } - return clonedFunc; + // Clone a global value, where `originalVal` is one declaration/definition, but we might + // have to consider others, in order to find the "best" version of the symbol. + IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal) + { + // We are being asked to clone a particular global value, but in + // the IR that comes out of the front-end there could still + // be multiple, target-specific, declarations of any given + // global value, all of which share the same mangled name. + return cloneGlobalValueWithMangledName( + context, + originalVal->mangledName, + originalVal); } StructTypeLayout* getGlobalStructLayout( @@ -3857,7 +3937,7 @@ namespace Slang globalVar = globalVar->getNextValue(); } SLANG_ASSERT(table); - table = cloneWitnessTable(context, (IRWitnessTable*)(table)); + table = cloneWitnessTableWithoutRegistering(context, (IRWitnessTable*)(table)); IRProxyVal * tableVal = new IRProxyVal(); tableVal->inst = table; paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); @@ -3868,55 +3948,57 @@ namespace Slang return globalParamSubst; } - IRModule* specializeIRForEntryPoint( + struct IRSpecializationState + { + ProgramLayout* programLayout; + CodeGenTarget target; + TargetRequest* targetReq; + + IRModule* irModule = nullptr; + RefPtr<ProgramLayout> newProgramLayout; + + IRSharedSpecContext sharedContextStorage; + IRSpecContext contextStorage; + + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } + IRSpecContext* getContext() { return &contextStorage; } + }; + + IRSpecializationState* createIRSpecializationState( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq) { + IRSpecializationState* state = new IRSpecializationState(); + + state->programLayout = programLayout; + state->target = target; + state->targetReq = targetReq; + + auto compileRequest = entryPointRequest->compileRequest; - auto session = compileRequest->mSession; auto translationUnit = entryPointRequest->getTranslationUnit(); auto originalIRModule = translationUnit->irModule; - if (!originalIRModule) - { - // We should already have emitted IR for the original - // translation unit, and it we don't have it, then - // we are now in trouble. - return nullptr; - } - - // We now need to start cloning IR symbols from `originalIRModule` - // into a fresh IR module for this entry point. Along the way we need to: - // - // 1. Attach layout information from `programLayout` and/or `entryPointLayout` - // onto the cloned IR symbols, to drive later code generation. - // - // 2. In cases where a function might have multiple target-specific definitions, - // we need to pick the "best" one for the chosen code generation target. - // - - IRSharedSpecContext sharedContextStorage; + auto sharedContext = state->getSharedContext(); initializeSharedSpecContext( - &sharedContextStorage, + sharedContext, compileRequest->mSession, nullptr, originalIRModule); + state->irModule = sharedContext->module; // We also need to attach the IR definitions for symbols from // any loaded modules: for (auto loadedModule : compileRequest->loadedModulesList) { - insertGlobalValueSymbols(&sharedContextStorage, loadedModule->irModule); + insertGlobalValueSymbols(sharedContext, loadedModule->irModule); } - // any loaded modules - - IRSpecContext contextStorage; - IRSpecContext* context = &contextStorage; - context->shared = &sharedContextStorage; - context->builder = &sharedContextStorage.builderStorage; + auto context = state->getContext(); + context->shared = sharedContext; + context->builder = &sharedContext->builderStorage; context->target = target; // Create the GlobalGenericParamSubstitution for substituting global generic types @@ -3928,7 +4010,7 @@ namespace Slang // now specailize the program layout using the substitution RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst); - auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + state->newProgramLayout = newProgramLayout; // Next, we want to optimize lookup for layout infromation // associated with global declarations, so that we can @@ -3940,6 +4022,69 @@ namespace Slang context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + return state; + } + + void destroyIRSpecializationState(IRSpecializationState* state) + { + delete state; + } + + IRModule* getIRModule(IRSpecializationState* state) + { + return state->irModule; + } + + IRGlobalValue* getSpecializedGlobalValueForDeclRef( + IRSpecializationState* state, + DeclRef<Decl> const& declRef) + { + // We will start be ensuring that we have code for + // the declaration itself. + auto decl = declRef.getDecl(); + auto mangledDeclName = getMangledName(decl); + + IRGlobalValue* irDeclVal = cloneGlobalValueWithMangledName( + state->getContext(), + mangledDeclName); + if(!irDeclVal) + return nullptr; + + // Now we need to deal with specializing the given + // IR value based on the substitutions applied to + // our declaration reference. + + if(!declRef.substitutions) + return irDeclVal; + + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(nullptr); + } + + void specializeIRForEntryPoint( + IRSpecializationState* state, + EntryPointRequest* entryPointRequest) + { + auto target = state->target; + + auto compileRequest = entryPointRequest->compileRequest; + auto session = compileRequest->mSession; + auto translationUnit = entryPointRequest->getTranslationUnit(); + auto originalIRModule = translationUnit->irModule; + if (!originalIRModule) + { + // We should already have emitted IR for the original + // translation unit, and it we don't have it, then + // we are now in trouble. + return; + } + + auto context = state->getContext(); + auto newProgramLayout = state->newProgramLayout; + + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + + // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else @@ -3964,8 +4109,6 @@ namespace Slang default: break; } - - return sharedContextStorage.module; } // diff --git a/source/slang/ir.h b/source/slang/ir.h index f2069e7c3..a6f5b30a1 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -466,6 +466,7 @@ void printSlangIRAssembly(StringBuilder& builder, IRModule* module); String getSlangIRAssembly(IRModule* module); void dumpIR(IRModule* module); + } diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 452e7c439..97deeb544 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -241,6 +241,12 @@ struct OptionsParser int argc, char const* const* argv) { + // Copy some state out of the current request, in case we've been called + // after some other initialization has been performed. + flags = requestImpl->compileFlags; + + // + char const* const* argCursor = &argv[0]; char const* const* argEnd = &argv[argc]; while (argCursor != argEnd) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 204313e84..ea86663ea 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -203,15 +203,11 @@ void CompileRequest::generateIR() // in isolation. for( auto& translationUnit : translationUnits ) { - // If the user opted out of semantic checking for - // the translation unit, then IR code generation - // is not in general even possible; there might - // be semantics errors (diagnosed or not) in the - // code, and we don't want to deal with those. - if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + // Also skip IR generation if semantic checking is turned off + // for a given translation unit. + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) continue; - // Okay, we seem to be in the clear now. translationUnit->irModule = generateIRForTranslationUnit(translationUnit); } } @@ -462,6 +458,7 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( // semantic checking to be enabled. // // TODO: decide which options, if any, should be inherited. + translationUnit->compileFlags = this->compileFlags & (SLANG_COMPILE_FLAG_USE_IR); RefPtr<SourceFile> sourceFile = getSourceManager()->allocateSourceFile(path, source); @@ -486,6 +483,8 @@ void CompileRequest::handlePoundImport( RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); translationUnit->compileRequest = this; + translationUnit->compileFlags = this->compileFlags & (SLANG_COMPILE_FLAG_USE_IR); + // Imported code is always native Slang code RefPtr<Scope> languageScope = mSession->slangLanguageScope; |
