diff options
Diffstat (limited to 'source/slang/ir.cpp')
| -rw-r--r-- | source/slang/ir.cpp | 193 |
1 files changed, 139 insertions, 54 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 994ac82ff..7318bff4c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -7,6 +7,13 @@ namespace Slang { + struct IRSpecContext; + + IRGlobalValue* cloneGlobalValueWithMangledName( + IRSpecContext* context, + String const& mangledName, + IRGlobalValue* originalVal); + static const IROpInfo kIROpInfos[] = { @@ -3065,6 +3072,9 @@ namespace Slang struct IRSharedSpecContext { + // The code-generation target in use + CodeGenTarget target; + // The specialized module we are building IRModule* module; @@ -3091,6 +3101,10 @@ namespace Slang struct IRSpecContextBase { + // A map from the mangled name of a global variable + // to the layout to use for it. + Dictionary<String, VarLayout*> globalVarLayouts; + IRSharedSpecContext* shared; IRSharedSpecContext* getShared() { return shared; } @@ -3224,13 +3238,6 @@ namespace Slang struct IRSpecContext : IRSpecContextBase { - // The code-generation target in use - CodeGenTarget target; - - // A map from the mangled name of a global variable - // to the layout to use for it. - Dictionary<String, VarLayout*> globalVarLayouts; - // Override the "maybe clone" logic so that we always clone virtual IRValue* maybeCloneValue(IRValue* originalVal) override; @@ -3434,18 +3441,31 @@ namespace Slang return newDeclRef; } - IRValue* cloneValue( IRSpecContextBase* context, IRValue* originalValue) { IRValue* clonedValue = nullptr; if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) + { return clonedValue; + } return context->maybeCloneValue(originalValue); } + IRValue* maybeCloneValueWithMangledName( + IRSpecContextBase* context, + IRGlobalValue* originalValue) + { + for (auto gv = context->shared->module->firstGlobalValue; gv; gv = gv->nextGlobalValue) + { + if (gv->mangledName == originalValue->mangledName) + return gv; + } + return cloneValue(context, originalValue); + } + void cloneInst( IRSpecContextBase* context, IRBuilder* builder, @@ -3468,18 +3488,23 @@ namespace Slang context->maybeCloneType(originalInst->type), 0, nullptr, argCount, nullptr); - builder->addInst(clonedInst); registerClonedValue(context, clonedInst, originalInst); - - cloneDecorations(context, clonedInst, originalInst); - + auto oldBuilder = context->builder; + context->builder = builder; for (UInt aa = 0; aa < argCount; ++aa) { IRValue* originalArg = originalInst->getArg(aa); - IRValue* clonedArg = cloneValue(context, originalArg); - + IRValue* clonedArg; + if (originalArg->op == kIROp_witness_table) + clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context, + ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg); + else + clonedArg = cloneValue(context, originalArg); clonedInst->getArgs()[aa].init(clonedInst, clonedArg); } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); } break; @@ -3524,12 +3549,15 @@ namespace Slang IRSpecContextBase* context, IRWitnessTable* originalTable, IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr) + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) { auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable(); - registerClonedValue(context, clonedTable, originalValues); + if (registerValue) + registerClonedValue(context, clonedTable, originalValues); auto mangledName = originalTable->mangledName; + clonedTable->mangledName = mangledName; clonedTable->genericDecl = originalTable->genericDecl; clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef; @@ -3539,8 +3567,11 @@ namespace Slang // Clone the entries in the witness table as well for( auto originalEntry : originalTable->entries ) { - auto clonedKey = context->maybeCloneValue(originalEntry->requirementKey.usedValue); - auto clonedVal = context->maybeCloneValue(originalEntry->satisfyingVal.usedValue); + auto clonedKey = cloneValue(context, originalEntry->requirementKey.usedValue); + + // if a global val with the mangled name already exists, don't clone again + auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.usedValue)); + /*auto clonedEntry = */context->builder->createWitnessTableEntry( clonedTable, clonedKey, @@ -3555,7 +3586,7 @@ namespace Slang IRWitnessTable* originalTable, IRWitnessTable* dstTable = nullptr) { - return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable); + return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false); } void cloneGlobalValueWithCodeCommon( @@ -3690,14 +3721,6 @@ namespace Slang // and their instructions. cloneFunctionCommon(context, clonedFunc, originalFunc); - // for now, clone all unreferenced witness tables - for (auto gv = context->getOriginalModule()->getFirstGlobalValue(); - gv; gv = gv->getNextValue()) - { - if (gv->op == kIROp_witness_table) - cloneGlobalValue(context, (IRWitnessTable*)gv); - } - // We need to attach the layout information for // the entry point to this declaration, so that // we can use it to inform downstream code emit. @@ -3746,7 +3769,7 @@ namespace Slang // TODO: We shouldn't be using strings for this. String getTargetName(IRSpecContext* context) { - switch( context->target ) + switch( context->shared->target ) { case CodeGenTarget::HLSL: return "hlsl"; @@ -4035,7 +4058,8 @@ namespace Slang IRSharedSpecContext* sharedContext, Session* session, IRModule* module, - IRModule* originalModule) + IRModule* originalModule, + CodeGenTarget target) { SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; @@ -4053,7 +4077,7 @@ namespace Slang sharedContext->module = module; sharedContext->originalModule = originalModule; - + sharedContext->target = target; // We will populate a map with all of the IR values // that use the same mangled name, to make lookup easier // in other steps. @@ -4110,7 +4134,9 @@ namespace Slang sharedContext, compileRequest->mSession, nullptr, - originalIRModule); + originalIRModule, + target); + state->irModule = sharedContext->module; // We also need to attach the IR definitions for symbols from @@ -4123,7 +4149,6 @@ namespace Slang auto context = state->getContext(); context->shared = sharedContext; context->builder = &sharedContext->builderStorage; - context->target = target; // Create the GlobalGenericParamSubstitution for substituting global generic types // into user-provided type arguments @@ -4146,6 +4171,12 @@ namespace Slang context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + // for now, clone all unreferenced witness tables + for (auto sym :context->getSymbols()) + { + if (sym.Value->irGlobalValue->op == kIROp_witness_table) + cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); + } return state; } @@ -4263,7 +4294,31 @@ namespace Slang return symbol->irGlobalValue; } else - return nullptr; + { + // we don't have the required witness table yet, + // try to emit a specialize instruction to get one + auto subDeclRef = subtypeWitness->sub->AsDeclRefType(); + auto subDeclRefGen = DeclRef<Decl>(subDeclRef->declRef.decl, + createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl)); + + String genericName = getMangledNameForConformanceWitness( + subDeclRefGen, + subtypeWitness->sup); + if (context->getSymbols().TryGetValue(genericName, symbol)) + { + auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, symbol->irGlobalValue, subDeclRef->declRef); + return specInst; + } + else + { + SLANG_UNEXPECTED("witness table not exist"); + UNREACHABLE_RETURN(nullptr); + } + } + } + else if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + { + return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value); } else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) { @@ -4321,10 +4376,34 @@ namespace Slang return getIRValue(context, subst->args[argIndex]); } + else if (auto valDeclRef = declRef.As<GenericValueParamDecl>()) + { + // We have a constraint, but we need to find its index in the + // argument list of the substitutions. + UInt argIdx = 0; + bool found = false; + for (auto cd : genericDecl->Members) + { + if (cd.Ptr() == valDeclRef.getDecl()) + { + found = true; + break; + } + if (cd.As<GenericTypeParamDecl>()) + argIdx++; + else if (cd.As<GenericValueParamDecl>()) + argIdx++; + } + assert(found); + + assert(argIdx < subst->args.Count()); + + return getIRValue(context, subst->args[argIdx]); + } else { - SLANG_UNEXPECTED("unhandled case"); - return nullptr; + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(nullptr); } } @@ -4342,12 +4421,13 @@ namespace Slang // of the generic we are specializing, and in that case // we nee to translate it over to the equiavalent of // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == genSubst->genericDecl) + if(declRef.getDecl()->ParentDecl == genSubst->genericDecl && + (declRef.As<GenericTypeParamDecl>() || declRef.As<GenericValueParamDecl>()|| + declRef.As<GenericTypeConstraintDecl>())) { if (auto substVal = getSubstValue(this, declRef)) return substVal; } - int diff = 0; auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); if(!diff) @@ -4455,7 +4535,6 @@ namespace Slang // has already been made. To do that we will need to // compute the mangled name of the specialized function, // so that we can look for existing declarations. - String specMangledName; String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef), specDeclRef.Substitute(originalTable->supTypeDeclRef)); @@ -4466,13 +4545,15 @@ namespace Slang // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. // We can probalby use the same basic context, actually. - auto module = originalTable->parentModule; - for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + if (!dstTable) { - if (gv->mangledName == specMangledName) - return (IRWitnessTable*)gv; + auto module = sharedContext->module; + for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + { + if (gv->mangledName == specializedMangledName) + return (IRWitnessTable*)gv; + } } - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( sharedContext, specDeclRef.substitutions.genericSubstitutions, @@ -4483,13 +4564,12 @@ namespace Slang context.builder = &sharedContext->builderStorage; context.subst = specDeclRef.substitutions; context.subst.genericSubstitutions = newSubst; - // TODO: other initialization is needed here... auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable); // Set up the clone to recognize that it is no longer generic - specTable->mangledName = specMangledName; + specTable->mangledName = specializedMangledName; specTable->genericDecl = nullptr; // Specialization of witness tables should trigger cascading specializations @@ -4499,8 +4579,9 @@ namespace Slang if (entry->satisfyingVal.usedValue->op == kIROp_Func) { IRFunc* func = (IRFunc*)entry->satisfyingVal.usedValue; - if (func->getGenericDecl()) - entry->satisfyingVal.set(getSpecializedFunc(sharedContext, func, specDeclRef)); + auto specFunc = getSpecializedFunc(sharedContext, func, specDeclRef); + entry->satisfyingVal.set(specFunc); + insertGlobalValueSymbol(sharedContext, specFunc); } } @@ -4526,13 +4607,16 @@ namespace Slang specMangledName = getMangledName(specDeclRef); else specMangledName = mangleSpecializedFuncName(genericFunc->mangledName, specDeclRef.substitutions); - + RefPtr<IRSpecSymbol> symb; + if (sharedContext->symbols.TryGetValue(specMangledName, symb)) + { + return (IRFunc*)(symb->irGlobalValue); + } // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. // We can probalby use the same basic context, actually. - auto module = genericFunc->parentModule; - for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + for (auto gv = sharedContext->module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) { if (gv->mangledName == specMangledName) return (IRFunc*) gv; @@ -4639,7 +4723,8 @@ namespace Slang // are known, and specialize the callee based on those // known values. void specializeGenerics( - IRModule* module) + IRModule* module, + CodeGenTarget target) { IRSharedSpecContext sharedContextStorage; auto sharedContext = &sharedContextStorage; @@ -4648,7 +4733,8 @@ namespace Slang sharedContext, module->session, module, - module); + module, + target); // Our goal here is to find `specialize` instructions that // can be replaced with references to a suitably sepcialized @@ -4895,11 +4981,10 @@ namespace Slang table = findWitnessTableByName(genericWitnessTableName); SLANG_ASSERT(table); WitnessTableSpecializationWorkItem workItem; - workItem.srcTable = (IRWitnessTable*)table; + workItem.srcTable = (IRWitnessTable*)cloneGlobalValue(context, (IRWitnessTable*)(table)); workItem.dstTable = context->builder->createWitnessTable(); workItem.dstTable->mangledName = getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup); workItem.specDeclRef = subDeclRefType->declRef; - registerClonedValue(context, workItem.dstTable, workItem.srcTable); witnessTablesToSpecailize.Add(workItem); table = workItem.dstTable; } |
