#include "slang-ir-specialize-address-space.h" #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" namespace Slang { struct AddressSpaceContext : public AddressSpaceSpecializationContext { IRModule* module; Dictionary mapInstToAddrSpace; InitialAddressSpaceAssigner* addrSpaceAssigner; HashSet functionsToConsiderRemoving; AddressSpaceContext(IRModule* inModule, InitialAddressSpaceAssigner* inAddrSpaceAssigner) : module(inModule), addrSpaceAssigner(inAddrSpaceAssigner) { } AddressSpace getAddressSpaceFromVarType(IRInst* type) { return addrSpaceAssigner->getAddressSpaceFromVarType(type); } AddressSpace getLeafInstAddressSpace(IRInst* inst) { return addrSpaceAssigner->getLeafInstAddressSpace(inst); } AddressSpace getAddrSpace(IRInst* inst) override { auto addrSpace = mapInstToAddrSpace.tryGetValue(inst); if (addrSpace) return *addrSpace; return AddressSpace::Generic; } List workList; struct FuncSpecializationKey { private: IRFunc* func; List argAddrSpaces; HashCode hashCode; public: IRFunc* getFunc() const { return func; } ArrayView getArgAddrSpaces() const { return argAddrSpaces.getArrayView(); } FuncSpecializationKey() = default; FuncSpecializationKey(IRFunc* func, List argAddrSpaces) : func(func), argAddrSpaces(argAddrSpaces) { Hasher hasher; hasher.addHash(Slang::getHashCode(func)); for (auto addrSpace : argAddrSpaces) { hasher.addHash((HashCode)addrSpace); } hashCode = hasher.getResult(); } bool operator==(const FuncSpecializationKey& key) const { if (func != key.func) return false; if (argAddrSpaces.getCount() != key.argAddrSpaces.getCount()) return false; for (Index i = 0; i < argAddrSpaces.getCount(); i++) { if (argAddrSpaces[i] != key.argAddrSpaces[i]) return false; } return true; } HashCode getHashCode() const { return hashCode; } }; Dictionary functionSpecializations; IRFunc* specializeFunc(const FuncSpecializationKey& key) { auto func = key.getFunc(); IRCloneEnv cloneEnv; IRBuilder builder(module); // First, clone the function body. builder.setInsertBefore(func); auto specializedFunc = as(cloneInst(&cloneEnv, &builder, func)); // Update the parameter types with new address spaces in the specialized function. Index paramIndex = 0; for (auto param : specializedFunc->getParams()) { auto paramType = param->getFullType(); auto ptrType = as(paramType); if (ptrType) { auto paramAddrSpace = key.getArgAddrSpaces()[paramIndex]; auto newParamType = builder.getPtrType( ptrType->getOp(), ptrType->getValueType(), ptrType->getAccessQualifier(), paramAddrSpace); param->setFullType(newParamType); mapInstToAddrSpace[param] = paramAddrSpace; } paramIndex++; } // Update the function type. fixUpFuncType(specializedFunc); functionSpecializations[key] = specializedFunc; return specializedFunc; } AddressSpace getFuncResultAddrSpace(IRFunc* callee) { auto funcType = as(callee->getDataType()); return getAddressSpaceFromVarType(funcType->getResultType()); } // Return true if the address space of the function return type is changed. bool processFunction(IRFunc* func) { bool retValAddrSpaceChanged = false; bool changed = true; while (changed) { changed = false; for (auto block : func->getBlocks()) { bool isFirstBlock = block == func->getFirstBlock(); for (auto inst : block->getChildren()) { // If we have already assigned an address space to this instruction, then skip // it. if (mapInstToAddrSpace.containsKey(inst)) { // TODO: if the inst is a phi node, we need to check if the address space of // the phi arguments is consistent. If not, then we need to report an error. // For now, we just skip the checks. continue; } // If the inst already has a pointer/pointer-like type with explicit address // space, then use it. auto addrSpaceFromType = addrSpaceAssigner->getAddressSpaceFromVarType(inst->getDataType()); if (addrSpaceFromType != AddressSpace::Generic) { mapInstToAddrSpace[inst] = addrSpaceFromType; changed = true; // Don't return early if the inst itself is a call, as we may still need to // specialize it down below. if (inst->getOp() != kIROp_Call) continue; } // Try to assign an address space based on the instruction type, and specialize // calls. switch (inst->getOp()) { case kIROp_Var: case kIROp_RWStructuredBufferGetElementPtr: case kIROp_Load: { // The address space of these insts should be assigned by the initial // address space assigner. AddressSpace addrSpace = AddressSpace::Generic; if (addrSpaceAssigner->tryAssignAddressSpace(inst, addrSpace)) { mapInstToAddrSpace[inst] = addrSpace; changed = true; } break; } case kIROp_GetElementPtr: case kIROp_FieldAddress: case kIROp_GetOffsetPtr: case kIROp_BitCast: if (!mapInstToAddrSpace.containsKey(inst)) { auto addrSpace = getAddrSpace(inst->getOperand(0)); if (addrSpace != AddressSpace::Generic) { mapInstToAddrSpace[inst] = addrSpace; changed = true; } } break; case kIROp_Store: break; case kIROp_Param: if (!isFirstBlock) { auto phiArgs = getPhiArgs(inst); AddressSpace addrSpace = AddressSpace::Generic; for (auto arg : phiArgs) { auto argAddrSpace = getAddrSpace(arg); if (argAddrSpace != AddressSpace::Generic) { if (addrSpace != AddressSpace::Generic && addrSpace != argAddrSpace) { // TODO: this is an error in user code, because the // address spaces of the phi arguments don't match. } addrSpace = argAddrSpace; } } if (addrSpace != AddressSpace::Generic) { mapInstToAddrSpace[inst] = addrSpace; changed = true; } break; } break; case kIROp_Call: { auto callInst = as(inst); auto callee = as(inst->getOperand(0)); if (callee) { List argAddrSpaces; bool hasSpecializableArg = false; for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); auto addrSpace = getAddrSpace(arg); argAddrSpaces.add(addrSpace); if (addrSpace != AddressSpace::Generic) { hasSpecializableArg = true; } } if (!hasSpecializableArg) { workList.add(callee); break; } // If callee doesn't have a body, don't specialize. if (!callee->getFirstBlock()) break; FuncSpecializationKey key(callee, argAddrSpaces); IRFunc* specializedCallee = nullptr; if (IRFunc** specializedFunc = functionSpecializations.tryGetValue(key)) { specializedCallee = *specializedFunc; } else { specializedCallee = specializeFunc(key); workList.add(specializedCallee); } IRBuilder builder(callInst); builder.setInsertBefore(callInst); if (specializedCallee != callInst->getCallee()) { callInst = as(builder.replaceOperand( callInst->getOperands(), specializedCallee)); // At this point, the original callee may be left without uses. functionsToConsiderRemoving.add(callee); } auto callResultAddrSpace = getFuncResultAddrSpace(specializedCallee); if (callResultAddrSpace != AddressSpace::Generic) { mapInstToAddrSpace[callInst] = callResultAddrSpace; changed = true; } } } break; case kIROp_Return: { auto retVal = inst->getOperand(0); auto addrSpace = getAddrSpace(retVal); if (addrSpace != AddressSpace::Generic) { auto funcType = as(func->getDataType()); AddressSpace resultAddrSpace = getFuncResultAddrSpace(func); if (resultAddrSpace != addrSpace) { auto ptrResultType = as(funcType->getResultType()); SLANG_ASSERT(ptrResultType); IRBuilder builder(func); auto newResultType = builder.getPtrType( ptrResultType->getOp(), ptrResultType->getValueType(), ptrResultType->getAccessQualifier(), addrSpace); fixUpFuncType(func, newResultType); retValAddrSpaceChanged = true; } } } break; } } } } return retValAddrSpaceChanged; } static void setDataType(IRInst* inst, IRType* dataType) { auto rate = inst->getRate(); if (!rate) { inst->setFullType(dataType); return; } IRBuilder builder(inst); builder.setInsertBefore(inst); auto newType = builder.getRateQualifiedType(rate, dataType); inst->setFullType(newType); } void applyAddressSpaceToInstType() { for (auto [inst, addrSpace] : mapInstToAddrSpace) { auto ptrType = as(inst->getDataType()); if (ptrType) { if (ptrType->getAddressSpace() != addrSpace) { IRBuilder builder(inst); auto newType = builder.getPtrType( ptrType->getOp(), ptrType->getValueType(), ptrType->getAccessQualifier(), addrSpace); setDataType(inst, newType); } } } } void processModule() { for (auto globalInst : module->getGlobalInsts()) { auto addrSpace = getLeafInstAddressSpace(globalInst); if (addrSpace != AddressSpace::Generic) { mapInstToAddrSpace[globalInst] = addrSpace; } if (auto func = as(globalInst)) { if (func->findDecoration()) workList.add(func); } } HashSet newWorkList; while (workList.getCount()) { for (Index i = 0; i < workList.getCount(); i++) { auto func = workList[i]; bool resultTypeChanged = processFunction(func); if (resultTypeChanged) { for (auto use = func->firstUse; use; use = use->nextUse) { if (auto callInst = as(use->getUser())) { newWorkList.add(getParentFunc(callInst)); } } } } workList.clear(); for (auto f : newWorkList) workList.add(f); } applyAddressSpaceToInstType(); for (IRFunc* func : functionsToConsiderRemoving) { SLANG_ASSERT(!func->findDecoration()); if (!func->hasUses()) func->removeAndDeallocate(); } } }; void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrSpaceAssigner) { AddressSpaceContext context(module, addrSpaceAssigner); context.processModule(); } void propagateAddressSpaceFromInsts(List&& workList) { HashSet visited; auto addUserToWorkList = [&](IRInst* inst) { for (auto use = inst->firstUse; use; use = use->nextUse) { auto user = use->getUser(); if (visited.add(user)) workList.add(user); } }; for (auto item : workList) { visited.add(item); } for (Index i = 0; i < workList.getCount(); i++) { auto inst = workList[i]; IRBuilder builder(inst); auto instPtrType = as(inst->getDataType()); if (!instPtrType) continue; for (auto use = inst->firstUse; use; use = use->nextUse) { auto user = use->getUser(); builder.setInsertBefore(user); switch (user->getOp()) { case kIROp_Loop: case kIROp_UnconditionalBranch: { auto branch = as(user); UIndex phiIndex = (UIndex)(use - branch->getArgs()); auto param = getParamAt(branch->getTargetBlock(), phiIndex); if (!param) continue; user = param; break; } } switch (user->getOp()) { case kIROp_FieldAddress: case kIROp_GetElementPtr: case kIROp_GetOffsetPtr: case kIROp_Param: { auto valueType = tryGetPointedToType(&builder, user->getDataType()); if (!valueType) continue; auto newType = builder.getPtrTypeWithAddressSpace(valueType, instPtrType); if (newType != user->getDataType()) { user->setFullType(newType); addUserToWorkList(user); } break; } } } } } AddressSpace NoOpInitialAddressSpaceAssigner::getAddressSpaceFromVarType(IRInst* type) { if (auto ptrType = as(type)) { if (ptrType->hasAddressSpace()) return ptrType->getAddressSpace(); } return AddressSpace::Generic; } AddressSpace NoOpInitialAddressSpaceAssigner::getLeafInstAddressSpace(IRInst*) { return AddressSpace::Generic; } } // namespace Slang