diff options
Diffstat (limited to 'source/slang/slang-ir-explicit-global-context.cpp')
| -rw-r--r-- | source/slang/slang-ir-explicit-global-context.cpp | 523 |
1 files changed, 523 insertions, 0 deletions
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp new file mode 100644 index 000000000..68f23461b --- /dev/null +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -0,0 +1,523 @@ +// slang-ir-explicit-global-context.cpp +#include "slang-ir-explicit-global-context.h" + +#include "slang-ir-insts.h" + +namespace Slang +{ + +// The job of this pass is take global-scope declarations +// that are actually scoped to a single shader thread or +// thread-group, and wrap them up in an explicit "context" +// type that gets passed between functions. + +struct IntroduceExplicitGlobalContextPass +{ + IRModule* m_module = nullptr; + CodeGenTarget m_target = CodeGenTarget::Unknown; + + SharedIRBuilder* m_sharedBuilder = nullptr; + IRStructType* m_contextStructType = nullptr; + IRPtrType* m_contextStructPtrType = nullptr; + + IRGlobalParam* m_globalUniformsParam = nullptr; + List<IRGlobalVar*> m_globalVars; + List<IRFunc*> m_entryPoints; + + void processModule() + { + SharedIRBuilder sharedBuilder(m_module); + m_sharedBuilder = &sharedBuilder; + + IRBuilder builder(&sharedBuilder); + + // The global context will be represneted by a `struct` + // type with a name hint of `KernelContext`. + // + m_contextStructType = builder.createStructType(); + builder.addNameHintDecoration(m_contextStructType, UnownedTerminatedStringSlice("KernelContext")); + + // The context will usually be passed around by pointer, + // so we get and cache that pointer type up front. + // + m_contextStructPtrType = builder.getPtrType(m_contextStructType); + + // The transformation we will perform will need to affect + // global variables, global shader parameters, and entry-point + // function (at the very least), and we start with an explicit + // pass to collect these entities into explicit lists to simplify + // looping over them later. + // + for( auto inst : m_module->getGlobalInsts() ) + { + switch( inst->op ) + { + case kIROp_GlobalVar: + { + // A "global variable" in HLSL (and thus Slang) is actually + // a weird kind of thread-local variable, and so it cannot + // actually be lowered to a global variable on targets where + // globals behave like, well, globals. + // + auto globalVar = cast<IRGlobalVar>(inst); + + // One important exception is that CUDA *does* support + // global variables with the `__shared__` qualifer, with + // semantics that exactly match HLSL/Slang `groupshared`. + // + // We thus need to skip processing of global variables + // that were marked `groupshared`. In our current IR, + // this is represented as a variable with the `@GroupShared` + // rate on its type. + // + if( m_target == CodeGenTarget::CUDASource ) + { + if( as<IRGroupSharedRate>(globalVar->getRate()) ) + continue; + } + + m_globalVars.add(globalVar); + } + break; + + case kIROp_GlobalParam: + { + // Global parameters are another HLSL/Slang concept + // that doesn't have a parallel in langauges like C/C++. + // + auto globalParam = cast<IRGlobalParam>(inst); + + + // One detail we need to be careful about is that as a result + // of legalizing the varying parameters of kernels, we can end + // up with global parameters for varying parameters on CUDA + // (e.g., to represent `threadIdx`. We thus skip any global-scope + // parameters that are varying instead of uniform. + // + auto layoutDecor = globalParam->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(layoutDecor); + auto layout = as<IRVarLayout>(layoutDecor->getLayout()); + SLANG_ASSERT(layout); + if(isVaryingParameter(layout)) + continue; + + // Because of upstream passes, we expect there to be only a + // single global uniform parameter (at most). + // + // Note: If we ever changed out mind about the representation + // and wanted to support multiple global parameters, we could + // easily generalize this code to work with a list. + // + SLANG_ASSERT(!m_globalUniformsParam); + m_globalUniformsParam = globalParam; + } + break; + + case kIROp_Func: + { + // Every entry point function is going to need to be modified, + // so that it can explicit create the context that other + // operations will use. + + // We need to filter the IR functions to find only those + // that represent entry points. + // + auto func = cast<IRFunc>(inst); + if(!func->findDecoration<IREntryPointDecoration>()) + continue; + + m_entryPoints.add(func); + } + break; + } + } + + // Now that we've capture all the relevant global entities from the IR, + // we can being to transform them in an appropriate order. + // + // The first step will be to create fields in the `KernelContext` + // type to represent any global parameters or global variables. + // + // The keys for the fields that are created will be remembered + // in a dictionary, so that we can find them later based on + // the global parameter/variable. + // + if( m_globalUniformsParam ) + { + // For the parameter representing all the global uniform shader + // parameters, we create a field that exactly matches its type. + // + createContextStructField(m_globalUniformsParam, m_globalUniformsParam->getFullType()); + } + for( auto globalVar : m_globalVars ) + { + // A `IRGlobalVar` represents a pointer to where the variable is stored, + // so we need to create a field of the pointed-to type to represent it. + // + createContextStructField(globalVar, globalVar->getDataType()->getValueType()); + } + + // Once all the fields have been created, we can process the entry points. + // + // Each entry point will create a local `KernelContext` variable and + // initialize it based on the parameters passed to the entry point. + // + // The local variable introduced here will be registered as the representation + // of the context to be used in the body of the entry point. + // + for( auto entryPoint : m_entryPoints ) + { + createContextForEntryPoint(entryPoint); + } + + // Now that we've prepared all the entry points, we can make another + // pass over the global parameters/variables and start to replace + // their use sites with references to the fields of the context. + // + // Wherever a global parameter/variable is being referenced in a function, + // we will need to find or create a context value for that function + // to use. The context value for entry points has already been established + // above, but other functions will have an explicit context parameter + // added on demand. + // + if( m_globalUniformsParam ) + { + replaceUsesOfGlobalParam(m_globalUniformsParam); + } + for( auto globalVar : m_globalVars ) + { + replaceUsesOfGlobalVar(globalVar); + } + } + + // As noted above, we will maintain mappings to record + // the key for the context field created for a global + // variable parameter, and to record the context pointer + // value to use for a function. + // + Dictionary<IRInst*, IRStructKey*> m_mapInstToContextFieldKey; + Dictionary<IRFunc*, IRInst*> m_mapFuncToContextPtr; + + void createContextStructField(IRInst* originalInst, IRType* type) + { + // Creating a field in the context struct to represent + // `originalInst` is straightforward. + + IRBuilder builder(m_sharedBuilder); + builder.setInsertBefore(m_contextStructType); + + // We create a "key" for the new field, and then a field + // of the appropraite type. + // + auto key = builder.createStructKey(); + auto field = builder.createStructField(m_contextStructType, key, type); + + // If the original instruction had a name hint on it, + // then we transfer that name hint over to the key, + // so that the field will have the name of the former + // global variable/parameter. + // + if( auto nameHint = originalInst->findDecoration<IRNameHintDecoration>() ) + { + nameHint->insertAtStart(key); + } + + // Any other decorations on the original instruction + // (e.g., pertaining to layout) need to be transferred + // over to the field (not the key). + // + originalInst->transferDecorationsTo(field); + + // We end by making note of the key that was created + // for the instruction, so that we can use the key + // to access the field later. + // + m_mapInstToContextFieldKey.Add(originalInst, key); + } + + void createContextForEntryPoint(IRFunc* entryPointFunc) + { + // We can only introduce the explicit context into + // entry points that have definitions. + // + auto firstBlock = entryPointFunc->getFirstBlock(); + if(!firstBlock) + return; + + IRBuilder builder(m_sharedBuilder); + + // The code we introduce will all be added to the start + // of the first block of the function. + // + auto firstOrdinary = firstBlock->getFirstOrdinaryInst(); + builder.setInsertBefore(firstOrdinary); + + // If there was a global-scope uniform parameter before, + // then we need to introduce an explicit parameter onto + // each entry-point function to represent it. + // + IRParam* globalUniformsParam = nullptr; + if( m_globalUniformsParam ) + { + globalUniformsParam = builder.createParam(m_globalUniformsParam->getFullType()); + if( auto nameHint = m_globalUniformsParam->findDecoration<IRNameHintDecoration>() ) + { + builder.addNameHintDecoration(globalUniformsParam, nameHint->getNameOperand()); + } + + // The new parameter will be the last one in the + // parameter list of the entry point. + // + globalUniformsParam->insertBefore(firstOrdinary); + } + + // The `KernelContext` to use inside the entry point + // will be a local variable declared in the first block. + // + auto contextVarPtr = builder.emitVar(m_contextStructType); + addKernelContextNameHint(contextVarPtr); + m_mapFuncToContextPtr.Add(entryPointFunc, contextVarPtr); + + // If there is a global-scope uniform parameter, then + // we need to use our new explicit entry point parameter + // to inialize the corresponding field of the `KernelContext` + // before moving on with execution of the kernel body. + // + if(m_globalUniformsParam) + { + auto fieldKey = m_mapInstToContextFieldKey[m_globalUniformsParam]; + auto fieldType = globalUniformsParam->getFullType(); + auto fieldPtrType = builder.getPtrType(fieldType); + + // We compute the addrress of the field and store the + // value of the parameter into it. + // + auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldKey); + builder.emitStore(fieldPtr, globalUniformsParam); + } + + // Note: at this point the `KernelContext` has additional + // fields for global variables that do not seem to have + // been initialized. + // + // Instead of making this pass take responsibility for initializing + // global variables, it is instead expected that clients will + // run the pass in `slang-ir-explicit-global-init` first, + // in order to move all initialization of globals into the + // entry point functions. + } + + void replaceUsesOfGlobalParam(IRGlobalParam* globalParam) + { + IRBuilder builder(m_sharedBuilder); + + // A global shader parameter was mapped to a field + // in the context structure, so we find the appropriate key. + // + auto key = m_mapInstToContextFieldKey[globalParam]; + + auto valType = globalParam->getFullType(); + auto ptrType = builder.getPtrType(valType); + + // We then iterate over the uses of the parameter, + // being careful to defend against the use/def information + // being changed while we walk it. + // + IRUse* nextUse = nullptr; + for( IRUse* use = globalParam->firstUse; use; use = nextUse ) + { + nextUse = use->nextUse; + + // At each use site, we need to look up the context + // pointer that is appropriate for that use. + // + auto user = use->getUser(); + auto contextParam = findOrCreateContextPtrForInst(user); + builder.setInsertBefore(user); + + // The value of the parameter can be produced by + // taking the address of the corresponding field + // in the context struct and loading from it. + // + auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + auto val = builder.emitLoad(valType, ptr); + use->set(val); + } + } + + void replaceUsesOfGlobalVar(IRGlobalVar* globalVar) + { + IRBuilder builder(m_sharedBuilder); + + // A global variable was mapped to a field + // in the context structure, so we find the appropriate key. + // + auto key = m_mapInstToContextFieldKey[globalVar]; + + auto ptrType = globalVar->getDataType(); + + // We then iterate over the uses of the variable, + // being careful to defend against the use/def information + // being changed while we walk it. + // + IRUse* nextUse = nullptr; + for( IRUse* use = globalVar->firstUse; use; use = nextUse ) + { + nextUse = use->nextUse; + + // At each use site, we need to look up the context + // pointer that is appropriate for that use. + // + auto user = use->getUser(); + auto contextParam = findOrCreateContextPtrForInst(user); + builder.setInsertBefore(user); + + // The address of the variable can be produced by + // taking the address of the corresponding field + // in the context struct. + // + auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + use->set(ptr); + } + } + + IRInst* findOrCreateContextPtrForInst(IRInst* inst) + { + // When looking up the context pointer to use for + // an instruction, we need to find the enclosing + // function and use whatever context pointer it uses. + // + for( IRInst* i = inst; i; i = i->getParent() ) + { + if( auto func = as<IRFunc>(i) ) + { + return findOrCreateContextPtrForFunc(func); + } + } + + // If a non-constant global entity is being referenced by + // something that is *not* nested under an IR function, then + // we are in trouble. + // + SLANG_UNEXPECTED("no outer func at use site for global"); + UNREACHABLE_RETURN(nullptr); + } + + IRInst* findOrCreateContextPtrForFunc(IRFunc* func) + { + // At this point we are being asked to either find or + // produce a context pointer for use inside `func`. + // + // If we already created such a pointer (perhaps because + // `func` is an entry point), then we are home free. + // + if( auto found = m_mapFuncToContextPtr.TryGetValue(func) ) + { + return *found; + } + + // Otherwise, we are going to need to introduce an + // explicit parameter to `func` to represent the + // context. + // + IRBuilder builder(m_sharedBuilder); + + // We can safely assume that `func` has a body, because + // otherwise we wouldn't be getting a request for the + // context pointer value to use in its body. + // + auto firstBlock = func->getFirstBlock(); + SLANG_ASSERT(firstBlock); + + // We create a new parameter at the end of the parameter + // list for `func`, with a type of `KernelContext*`. + // + IRParam* contextParam = builder.createParam(m_contextStructPtrType); + addKernelContextNameHint(contextParam); + contextParam->insertBefore(firstBlock->getFirstOrdinaryInst()); + + // The new parameter can be registerd as the context value + // to be used for `func` right away. + // + // Note: we register the value *before* modifying locations + // that call `func` to protect against a possible infinite-recursion + // situation if `func` is recursive along some path. + // + m_mapFuncToContextPtr.Add(func, contextParam); + + // Any code that calls `func` now needs to be updated to pass + // the context parameter. + // + // TODO: There is an issue here if `func` might be called + // dynamically, through something like a witness table. + // + List<IRUse*> uses; + for( auto use = func->firstUse; use; use = use->nextUse ) + { + // We will only fix up calls to `func`, and ignore + // other operations that might refer to it. + // + // TODO: We need to allow things like decorations that might + // refer to `func`, but this logic is also going to + // ignore things like witness tables that refer to `func`, + // or operations that pass `func` as a function pointer + // to a higher-order function. + // + auto call = as<IRCall>(use->getUser()); + if(!call) + continue; + + // We are going to construct a new call to `func` + // that has all of the arguments of the original call... + // + UInt originalArgCount = call->getArgCount(); + List<IRInst*> args; + for( UInt aa = 0; aa < originalArgCount; ++aa ) + { + args.add(call->getArg(aa)); + } + + // ... plus an additional argument representing + // the context pointer at the call site (note that + // this step leads to a potential for recursion in this pass; + // the maximum depth of the recursion is bounded by the + // maximum length of a cycle-free path through the call + // graph of the program). + // + args.add(findOrCreateContextPtrForInst(call)); + + // The new call will be emitted right before the old one, + // then used to replace it. + // + builder.setInsertBefore(call); + auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), args); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + + return contextParam; + } + + // Because we have multiple places where instructions representing + // the kernel context get introduced, we have factored out a subroutine + // for setting up the name hint to be used by those instructions. + // + void addKernelContextNameHint(IRInst* inst) + { + IRBuilder builder(m_sharedBuilder); + builder.addNameHintDecoration(inst, UnownedTerminatedStringSlice("kernelContext")); + } +}; + + /// Collect global-scope variables/paramters to form an explicit context that gets threaded through +void introduceExplicitGlobalContext( + IRModule* module, + CodeGenTarget target) +{ + IntroduceExplicitGlobalContextPass pass; + pass.m_module = module; + pass.m_target = target; + pass.processModule(); +} + +} |
