diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2018-04-11 16:18:29 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-04-11 16:18:29 -0700 |
| commit | baf194e7456ba4568dcf11249896af35b3ce18cc (patch) | |
| tree | f75e20db450100d41bfa9c384a8bab0fdc28a749 /source/slang/ir.cpp | |
| parent | 6322983fa4dc84ef1e9dd8fad54d4c1580436e67 (diff) | |
Introduce an IR-level type system (#481)
* Introduce an IR-level type system
Up to this point, the Slang IR has used the front-end type system to represent types in the IR.
As a result (but ultimately more importantly) the IR representation of generics and specialization has used AST-level concepts embedded in the IR.
For example, to express the specialization of `vector<T,N>` to a concrete type `float` for `T`, we needed an IR operation that could represent the specialization, with operands that somehow represented the type argument `float`.
The whole thing was very complicated.
The big idea of this change is to introduce a new representation in which types in the IR are just ordinary instructions, so that using them as operands makes sense. The hierarchy of IR types closely mirrors the AST-side hierarchy for now, and that will probably be something we should maintain going forward.
In order to make these changes work, though, I also had to do major overhauls of things like the way substitutions are performed, how we check interface conformances, the way lookup through interface types is done, etc. etc. This is a big change, and unfortunately any attempt to summarize it in the commit message wouldn't do it justice.
* Fix 64-bit build warning
* Fix up some clang warnings/errors
Diffstat (limited to 'source/slang/ir.cpp')
| -rw-r--r-- | source/slang/ir.cpp | 3358 |
1 files changed, 1630 insertions, 1728 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 75f43453a..2615c1c07 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -14,38 +14,34 @@ namespace Slang Name* mangledName, IRGlobalValue* originalVal); - - static const IROpInfo kIROpInfos[] = + struct IROpMapEntry { -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, ARG_COUNT, FLAGS, }, -#include "ir-inst-defs.h" + IROp op; + IROpInfo info; }; - // - - IROp findIROp(char const* name) + // TODO: We should ideally be speeding up the name->inst + // mapping by using a dictionary, or even by pre-computing + // a hash table to be stored as a `static const` array. + static const IROpMapEntry kIROps[] = { - // TODO: need to make this faster by using a dictionary... - - static const struct { - char const* mnemonic; - IROp op; - } kOps[] = { + { kIROp_Invalid, { "invalid", 0, 0 } }, #define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, kIROp_##ID }, - + { kIROp_##ID, { #MNEMONIC, ARG_COUNT, FLAGS, } }, #define PSEUDO_INST(ID) \ - { #ID, kIRPseudoOp_##ID }, - + { kIRPseudoOp_##ID, { #ID, 0, 0 } }, #include "ir-inst-defs.h" - }; + }; + + // - for (auto ee : kOps) + IROp findIROp(char const* name) + { + for (auto ee : kIROps) { - if (strcmp(name, ee.mnemonic) == 0) + if (strcmp(name, ee.info.name) == 0) return ee.op; } @@ -54,7 +50,13 @@ namespace Slang IROpInfo getIROpInfo(IROp op) { - return kIROpInfos[op]; + for (auto ee : kIROps) + { + if (ee.op == op) + return ee.info; + } + + return kIROps[0].info; } // @@ -65,7 +67,6 @@ namespace Slang auto uv = this->usedValue; if(!uv) { - assert(!user); assert(!nextUse); assert(!prevLink); return; @@ -160,6 +161,22 @@ namespace Slang return nullptr; } + // IRConstant + + IRIntegerValue GetIntVal(IRInst* inst) + { + switch (inst->op) + { + default: + SLANG_UNEXPECTED("needed a known integer value"); + UNREACHABLE_RETURN(0); + + case kIROp_IntLit: + return ((IRConstant*)inst)->u.intVal; + break; + } + } + // IRParam IRParam* IRParam::getNextParam() @@ -167,6 +184,17 @@ namespace Slang return as<IRParam>(getNextInst()); } + // IRArrayTypeBase + + IRInst* IRArrayTypeBase::getElementCount() + { + if (auto arrayType = as<IRArrayType>(this)) + return arrayType->getElementCount(); + + return nullptr; + } + + // IRBlock IRParam* IRBlock::getLastParam() @@ -416,13 +444,7 @@ namespace Slang return (IRBlock*)use->get(); } - // IRFunc - - IRType* IRFunc::getResultType() { return getType()->getResultType(); } - UInt IRFunc::getParamCount() { return getType()->getParamCount(); } - IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); } - - IRParam* IRFunc::getFirstParam() + IRParam* IRGlobalValueWithParams::getFirstParam() { auto entryBlock = getFirstBlock(); if(!entryBlock) return nullptr; @@ -430,6 +452,12 @@ namespace Slang return entryBlock->getFirstParam(); } + // IRFunc + + IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } + UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } + IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } + void IRGlobalValueWithCode::addBlock(IRBlock* block) { block->insertAtEnd(this); @@ -589,7 +617,7 @@ namespace Slang { if (rr == leftNonBlock) { - SLANG_ASSERT(!parentNonBlock); + SLANG_ASSERT(!parentNonBlock || parentNonBlock == leftNonBlock); parentNonBlock = rightNonBlock; break; } @@ -677,6 +705,9 @@ namespace Slang for (UInt ii = 0; ii < operandCount; ++ii) { auto operand = inst->getOperand(ii); + if (!operand) + continue; + auto operandParent = operand->getParent(); parent = mergeCandidateParentsForHoistableInst(parent, operandParent); @@ -727,22 +758,6 @@ namespace Slang value->sourceLoc = sourceLocInfo->sourceLoc; } - template<typename T> - static T* createValue( - IRBuilder* builder, - IROp op, - IRType* type) - { - assert(builder->getModule()); - T* value = (T*)builder->getModule()->memoryPool.allocZero(sizeof(T)); - new(value)T(); - value->op = op; - value->type = type; - builder->getModule()->irObjectsToFree.Add(value); - return value; - } - - // Create an IR instruction/value and initialize it. // // In this case `argCount` and `args` represnt the @@ -752,23 +767,39 @@ namespace Slang static T* createInstImpl( IRModule* module, IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, - IRInst* const* varArgs = nullptr) + UInt varArgListCount, + UInt const* listArgCounts, + IRInst* const* const* listArgs) { + UInt varArgCount = 0; + for (UInt ii = 0; ii < varArgListCount; ++ii) + { + varArgCount += listArgCounts[ii]; + } + + UInt size = sizeof(IRInst) + (fixedArgCount + varArgCount) * sizeof(IRUse); + if (sizeof(T) > size) + { + size = sizeof(T); + } + assert(module); T* inst = (T*)module->memoryPool.allocZero(size); new(inst)T(); + inst->operandCount = (uint32_t)(fixedArgCount + varArgCount); inst->op = op; - inst->type = type; + if (type) + { + inst->typeUse.init(inst, type); + } maybeSetSourceLoc(builder, inst); @@ -783,13 +814,21 @@ namespace Slang operand++; } - for( UInt aa = 0; aa < varArgCount; ++aa ) + for (UInt ii = 0; ii < varArgListCount; ++ii) { - if (varArgs) + UInt listArgCount = listArgCounts[ii]; + for (UInt jj = 0; jj < listArgCount; ++jj) { - operand->init(inst, varArgs[aa]); + if (listArgs[ii]) + { + operand->init(inst, listArgs[ii][jj]); + } + else + { + operand->init(inst, nullptr); + } + operand++; } - operand++; } module->irObjectsToFree.Add(inst); return inst; @@ -798,24 +837,46 @@ namespace Slang template<typename T> static T* createInstImpl( IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, + UInt varArgCount = 0, IRInst* const* varArgs = nullptr) { return createInstImpl<T>( builder->getModule(), builder, - size, op, type, fixedArgCount, fixedArgs, - varArgCount, - varArgs); + 1, + &varArgCount, + &varArgs); + } + + template<typename T> + static T* createInstImpl( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgListCount, + UInt const* listArgCount, + IRInst* const* const* listArgs) + { + return createInstImpl<T>( + builder->getModule(), + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgListCount, + listArgCount, + listArgs); } template<typename T> @@ -828,7 +889,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, argCount, @@ -843,7 +903,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, 0, @@ -859,7 +918,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, 1, @@ -877,7 +935,6 @@ namespace Slang IRInst* args[] = { arg1, arg2 }; return createInstImpl<T>( builder, - sizeof(T), op, type, 2, @@ -894,7 +951,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T) + argCount * sizeof(IRUse), op, type, argCount, @@ -913,7 +969,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -936,7 +991,6 @@ namespace Slang return createInstImpl<T>( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -949,7 +1003,7 @@ namespace Slang bool operator==(IRInstKey const& left, IRInstKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->parent != right.inst->parent) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->operandCount != right.inst->operandCount) return false; auto argCount = left.inst->operandCount; @@ -967,7 +1021,7 @@ namespace Slang int IRInstKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->parent)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->getOperandCount())); auto argCount = inst->getOperandCount(); @@ -984,7 +1038,7 @@ namespace Slang bool operator==(IRConstantKey const& left, IRConstantKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->type != right.inst->type) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false; if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false; return true; @@ -993,7 +1047,7 @@ namespace Slang int IRConstantKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->type)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0])); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1])); return code; @@ -1009,7 +1063,7 @@ namespace Slang IRConstant keyInst; memset(&keyInst, 0, sizeof(keyInst)); keyInst.op = op; - keyInst.type = type; + keyInst.typeUse.usedValue = type; memcpy(&keyInst.u, value, valueSize); IRConstantKey key; @@ -1029,7 +1083,7 @@ namespace Slang // way: we will construct a temporary instruction and // then use it to look up in a cache of instructions. - irValue = createValue<IRConstant>(builder, op, type); + irValue = createInst<IRConstant>(builder, op, type); memcpy(&irValue->u, value, valueSize); key.inst = irValue; @@ -1049,7 +1103,7 @@ namespace Slang return findOrEmitConstant( this, kIROp_boolConst, - getSession()->getBoolType(), + getBoolType(), sizeof(value), &value); } @@ -1074,72 +1128,330 @@ namespace Slang &value); } - IRUndefined* IRBuilder::emitUndefined(IRType* type) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands) { - auto inst = createInst<IRUndefined>( - this, - kIROp_undefined, - type); + UInt operandCount = 0; + for (UInt ii = 0; ii < operandListCount; ++ii) + { + operandCount += listOperandCounts[ii]; + } + + // We are going to create a dummy instruction on the stack, + // which will be used as a key for lookup, so see if we + // already have an equivalent instruction available to use. + + size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); + IRInst* keyInst = (IRInst*) malloc(keySize); + memset(keyInst, 0, keySize); + + new(keyInst) IRInst(); + keyInst->op = op; + keyInst->typeUse.usedValue = type; + keyInst->operandCount = (uint32_t) operandCount; + + IRUse* operand = keyInst->getOperands(); + for (UInt ii = 0; ii < operandListCount; ++ii) + { + UInt listOperandCount = listOperandCounts[ii]; + for (UInt jj = 0; jj < listOperandCount; ++jj) + { + operand->usedValue = listOperands[ii][jj]; + operand++; + } + } + + IRInstKey key; + key.inst = keyInst; + + IRInst* foundInst = nullptr; + bool found = builder->sharedBuilder->globalValueNumberingMap.TryGetValue(key, foundInst); + + free((void*)keyInst); + + if (found) + { + return foundInst; + } + + // If no instruction was found, then we need to emit it. + + IRInst* inst = createInstImpl<IRInst>( + builder, + op, + type, + 0, + nullptr, + operandListCount, + listOperandCounts, + listOperands); + addHoistableInst(builder, inst); + + key.inst = inst; + builder->sharedBuilder->globalValueNumberingMap.Add(key, inst); - addInst(inst); - return inst; } - IRInst* IRBuilder::getDeclRefVal( - DeclRefBase const& declRef) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandCount, + IRInst* const* operands) + { + return findOrEmitHoistableInst( + builder, + type, + op, + 1, + &operandCount, + &operands); + } + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands) + { + UInt counts[] = { 1, operandCount }; + IRInst* const* lists[] = { &operand, operands }; + + return findOrEmitHoistableInst( + builder, + type, + op, + 2, + counts, + lists); + } + + + IRType* IRBuilder::getType( + IROp op, + UInt operandCount, + IRInst* const* operands) { - // TODO: we should cache these... - auto irValue = createValue<IRDeclRef>( + return (IRType*) findOrEmitHoistableInst( this, - kIROp_decl_ref, - nullptr); - irValue->declRef = DeclRef<Decl>(declRef.decl, declRef.substitutions); + nullptr, + op, + operandCount, + operands); + } - addHoistableInst(this, irValue); + IRType* IRBuilder::getType( + IROp op) + { + return getType(op, 0, nullptr); + } - return irValue; + IRBasicType* IRBuilder::getBasicType(BaseType baseType) + { + return (IRBasicType*)getType( + IROp((UInt)kIROp_FirstBasicType + (UInt)baseType)); + } + + IRBasicType* IRBuilder::getVoidType() + { + return (IRVoidType*)getType(kIROp_VoidType); + } + + IRBasicType* IRBuilder::getBoolType() + { + return (IRBoolType*)getType(kIROp_BoolType); + } + + IRBasicType* IRBuilder::getIntType() + { + return (IRBasicType*)getType(kIROp_IntType); + } + + IRBasicBlockType* IRBuilder::getBasicBlockType() + { + return (IRBasicBlockType*)getType(kIROp_BasicBlockType); + } + + IRTypeKind* IRBuilder::getTypeKind() + { + return (IRTypeKind*)getType(kIROp_TypeKind); + } + + IRGenericKind* IRBuilder::getGenericKind() + { + return (IRGenericKind*)getType(kIROp_GenericKind); + } + + IRPtrType* IRBuilder::getPtrType(IRType* valueType) + { + return (IRPtrType*) getPtrType(kIROp_PtrType, valueType); + } + + IROutType* IRBuilder::getOutType(IRType* valueType) + { + return (IROutType*) getPtrType(kIROp_OutType, valueType); + } + + IRInOutType* IRBuilder::getInOutType(IRType* valueType) + { + return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); + } + + IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) + { + IRInst* operands[] = { valueType }; + return (IRPtrTypeBase*) getType( + op, + 1, + operands); + } + + IRArrayTypeBase* IRBuilder::getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayTypeBase*)getType( + op, + op == kIROp_ArrayType ? 2 : 1, + operands); } - IRInst* IRBuilder::getTypeVal(IRType * type) + IRArrayType* IRBuilder::getArrayType( + IRType* elementType, + IRInst* elementCount) { - auto irValue = createValue<IRInst>( + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayType*)getType( + kIROp_ArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRUnsizedArrayType* IRBuilder::getUnsizedArrayType( + IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRUnsizedArrayType*)getType( + kIROp_UnsizedArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRVectorType* IRBuilder::getVectorType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRVectorType*)getType( + kIROp_VectorType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRMatrixType* IRBuilder::getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount) + { + IRInst* operands[] = { elementType, rowCount, columnCount }; + return (IRMatrixType*)getType( + kIROp_MatrixType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRFuncType* IRBuilder::getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType) + { + return (IRFuncType*) findOrEmitHoistableInst( this, - kIROp_TypeType, - nullptr); - irValue->type = type; - if (auto typetype = dynamic_cast<TypeType*>(type)) - irValue->type = typetype->type; - return irValue; + nullptr, + kIROp_FuncType, + resultType, + paramCount, + (IRInst* const*) paramTypes); } - IRInst* IRBuilder::emitSpecializeInst( - Type* type, - IRInst* genericVal, - IRInst* specDeclRef) + IRConstExprRate* IRBuilder::getConstExprRate() + { + return (IRConstExprRate*)getType(kIROp_ConstExprRate); + } + + IRGroupSharedRate* IRBuilder::getGroupSharedRate() + { + return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); + } + + IRRateQualifiedType* IRBuilder::getRateQualifiedType( + IRRate* rate, + IRType* dataType) + { + IRInst* operands[] = { rate, dataType }; + return (IRRateQualifiedType*)getType( + kIROp_RateQualifiedType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) + { + if (auto oldRateQualifiedType = as<IRRateQualifiedType>(inst->getFullType())) + { + // Construct a new rate-qualified type using the same rate. + + auto newRateQualifiedType = getRateQualifiedType( + oldRateQualifiedType->getRate(), + dataType); + + inst->setFullType(newRateQualifiedType); + } + else + { + // No rate? Just clobber the data type. + inst->setFullType(dataType); + } + } + + + IRUndefined* IRBuilder::emitUndefined(IRType* type) { - auto inst = createInst<IRSpecialize>( + auto inst = createInst<IRUndefined>( this, - kIROp_specialize, - type, - genericVal, - specDeclRef); + kIROp_undefined, + type); + addInst(inst); + return inst; } IRInst* IRBuilder::emitSpecializeInst( - Type* type, + IRType* type, IRInst* genericVal, - DeclRef<Decl> specDeclRef) + UInt argCount, + IRInst* const* args) { - auto specDeclRefVal = getDeclRefVal(specDeclRef); - auto inst = createInst<IRSpecialize>( + auto inst = createInstWithTrailingArgs<IRSpecialize>( this, - kIROp_specialize, + kIROp_Specialize, type, - genericVal, - specDeclRefVal); + 1, + &genericVal, + argCount, + args); + addInst(inst); return inst; } @@ -1155,45 +1467,7 @@ namespace Slang type, witnessTableVal, interfaceMethodVal); - addInst(inst); - return inst; - } - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - DeclRef<Decl> witnessTableDeclRef, - DeclRef<Decl> interfaceMethodDeclRef) - { - auto witnessTableVal = getDeclRefVal(witnessTableDeclRef); - DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - DeclRef<Decl> interfaceMethodDeclRef) - { - DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - - IRInst* IRBuilder::emitFindWitnessTable( - DeclRef<Decl> baseTypeDeclRef, - IRType* interfaceType) - { - auto interfaceTypeDeclRef = interfaceType->AsDeclRefType(); - SLANG_ASSERT(interfaceTypeDeclRef); - auto inst = createInst<IRLookupWitnessTable>( - this, - kIROp_lookup_witness_table, - interfaceType, - getDeclRefVal(baseTypeDeclRef), - getDeclRefVal(interfaceTypeDeclRef->declRef)); addInst(inst); return inst; } @@ -1279,10 +1553,12 @@ namespace Slang auto moduleInst = createInstImpl<IRModuleInst>( module, this, - sizeof(IRModuleInst), kIROp_Module, nullptr, 0, + nullptr, + 0, + nullptr, nullptr); module->moduleInst = moduleInst; @@ -1290,58 +1566,103 @@ namespace Slang } void addGlobalValue( - IRModule* module, + IRBuilder* builder, IRGlobalValue* value) { - if(!module) - return; + // Try to find a suitable parent for the + // global value we are emitting. + // + // We will start out search at the current + // parent instruction for the builder, and + // possibly work our way up. + // + auto parent = builder->insertIntoParent; + while(parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as<IRModuleInst>(parent)) + break; - value->insertAtEnd(module->moduleInst); + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as<IRBlock>(parent)) + { + if (as<IRGeneric>(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + // If we somehow ran out of parents (possibly + // because an instruction wasn't linked into + // the full hierarchy yet), then we will + // fall back to inserting into the overall module. + if (!parent) + { + parent = builder->getModule()->getModuleInst(); + } + + // If it turns out that we are inserting into the + // current "insert into" parent for the builder, then + // we need to respect its "insert before" setting + // as well. + if (parent == builder->insertIntoParent + && builder->insertBeforeInst) + { + value->insertBefore(builder->insertBeforeInst); + } + else + { + value->insertAtEnd(parent); + } } IRFunc* IRBuilder::createFunc() { - IRFunc* rsFunc = createValue<IRFunc>( + IRFunc* rsFunc = createInst<IRFunc>( this, kIROp_Func, nullptr); maybeSetSourceLoc(this, rsFunc); - addGlobalValue(getModule(), rsFunc); + addGlobalValue(this, rsFunc); return rsFunc; } IRGlobalVar* IRBuilder::createGlobalVar( IRType* valueType) { - auto ptrType = getSession()->getPtrType(valueType); - IRGlobalVar* globalVar = createValue<IRGlobalVar>( + auto ptrType = getPtrType(valueType); + IRGlobalVar* globalVar = createInst<IRGlobalVar>( this, - kIROp_global_var, + kIROp_GlobalVar, ptrType); maybeSetSourceLoc(this, globalVar); - addGlobalValue(getModule(), globalVar); + addGlobalValue(this, globalVar); return globalVar; } IRGlobalConstant* IRBuilder::createGlobalConstant( IRType* valueType) { - IRGlobalConstant* globalConstant = createValue<IRGlobalConstant>( + IRGlobalConstant* globalConstant = createInst<IRGlobalConstant>( this, - kIROp_global_constant, + kIROp_GlobalConstant, valueType); maybeSetSourceLoc(this, globalConstant); - addGlobalValue(getModule(), globalConstant); + addGlobalValue(this, globalConstant); return globalConstant; } IRWitnessTable* IRBuilder::createWitnessTable() { - IRWitnessTable* witnessTable = createValue<IRWitnessTable>( + IRWitnessTable* witnessTable = createInst<IRWitnessTable>( this, - kIROp_witness_table, + kIROp_WitnessTable, nullptr); - addGlobalValue(getModule(), witnessTable); + addGlobalValue(this, witnessTable); return witnessTable; } @@ -1352,7 +1673,7 @@ namespace Slang { IRWitnessTableEntry* entry = createInst<IRWitnessTableEntry>( this, - kIROp_witness_table_entry, + kIROp_WitnessTableEntry, nullptr, requirementKey, satisfyingVal); @@ -1365,6 +1686,68 @@ namespace Slang return entry; } + IRStructType* IRBuilder::createStructType() + { + IRStructType* structType = createInst<IRStructType>( + this, + kIROp_StructType, + nullptr); + addGlobalValue(this, structType); + return structType; + } + + IRStructKey* IRBuilder::createStructKey() + { + IRStructKey* structKey = createInst<IRStructKey>( + this, + kIROp_StructKey, + nullptr); + addGlobalValue(this, structKey); + return structKey; + } + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* IRBuilder::createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType) + { + IRInst* operands[] = { fieldKey, fieldType }; + IRStructField* field = (IRStructField*) createInstWithTrailingArgs<IRInst>( + this, + kIROp_StructField, + nullptr, + 0, + nullptr, + 2, + operands); + + if (structType) + { + field->insertAtEnd(structType); + } + + return field; + } + + IRGeneric* IRBuilder::createGeneric() + { + IRGeneric* irGeneric = createInst<IRGeneric>( + this, + kIROp_Generic, + nullptr); + return irGeneric; + } + + IRGeneric* IRBuilder::emitGeneric() + { + auto irGeneric = createGeneric(); + addGlobalValue(this, irGeneric); + return irGeneric; + } + + IRWitnessTable * IRBuilder::lookupWitnessTable(Name* mangledName) { IRWitnessTable * result; @@ -1381,10 +1764,10 @@ namespace Slang IRBlock* IRBuilder::createBlock() { - return createValue<IRBlock>( + return createInst<IRBlock>( this, kIROp_Block, - getSession()->getIRBasicBlockType()); + getBasicBlockType()); } IRBlock* IRBuilder::emitBlock() @@ -1409,7 +1792,7 @@ namespace Slang IRParam* IRBuilder::createParam( IRType* type) { - auto param = createValue<IRParam>( + auto param = createInst<IRParam>( this, kIROp_Param, type); @@ -1430,7 +1813,7 @@ namespace Slang IRVar* IRBuilder::emitVar( IRType* type) { - auto allocatedType = getSession()->getPtrType(type); + auto allocatedType = getPtrType(type); auto inst = createInst<IRVar>( this, kIROp_Var, @@ -1449,12 +1832,12 @@ namespace Slang // results) at the "default" rate of the parent function, // unless a subsequent analysis pass constraints it. - RefPtr<Type> valueType; - if(auto ptrType = ptr->getDataType()->As<PtrTypeBase>()) + IRType* valueType = nullptr; + if(auto ptrType = as<IRPtrTypeBase>(ptr->getDataType())) { valueType = ptrType->getValueType(); } - else if(auto ptrLikeType = ptr->getDataType()->As<PointerLikeType>()) + else if(auto ptrLikeType = as<IRPointerLikeType>(ptr->getDataType())) { valueType = ptrLikeType->getElementType(); } @@ -1465,15 +1848,20 @@ namespace Slang return nullptr; } - // Ugly special case: the result of loading from `groupshared` - // memory should not itself be `groupshared`. + // Ugly special case: if the front-end created a variable with + // type `Ptr<@R T>` instead of `@R Ptr<T>`, then the above + // logic will yield `@R T` instead of `T`, and we need to + // try and fix that up here. + // + // TODO: Lowering to the IR should be fixed to never create + // that case: rate-qualified types should only be allowed + // to appear as the type of an instruction, and should not + // be allowed as operands to type constructors (except + // in special cases we decide to allow). // - // TODO: This special case will go away once `GroupSharedType` - // is replaced by a `GroupSharedRate` that gets used together - // with `RateQualifiedType`. - if(auto rateType = valueType->As<GroupSharedType>()) + if(auto rateType = as<IRRateQualifiedType>(valueType)) { - valueType = rateType->valueType; + valueType = rateType->getValueType(); } auto inst = createInst<IRLoad>( @@ -1589,7 +1977,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1631,7 +2019,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1802,6 +2190,30 @@ namespace Slang return inst; } + IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() + { + IRGlobalGenericParam* irGenericParam = createInst<IRGlobalGenericParam>( + this, + kIROp_GlobalGenericParam, + nullptr); + addGlobalValue(this, irGenericParam); + return irGenericParam; + } + + IRBindGlobalGenericParam* IRBuilder::emitBindGlobalGenericParam( + IRInst* param, + IRInst* val) + { + auto inst = createInst<IRBindGlobalGenericParam>( + this, + kIROp_BindGlobalGenericParam, + nullptr, + param, + val); + addInst(inst); + return inst; + } + IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) { auto decoration = addDecoration<IRHighLevelDeclDecoration>(inst, kIRDecorationOp_HighLevelDecl); @@ -1873,6 +2285,11 @@ namespace Slang bool opHasResult(IRInst* inst); + bool instHasUses(IRInst* inst) + { + return inst->firstUse != nullptr; + } + static UInt getID( IRDumpContext* context, IRInst* value) @@ -1881,7 +2298,7 @@ namespace Slang if (context->mapValueToID.TryGetValue(value, id)) return id; - if (opHasResult(value)) + if (opHasResult(value) || instHasUses(value)) { id = context->idCounter++; } @@ -1900,33 +2317,30 @@ namespace Slang return; } - switch(inst->op) + if (auto globalValue = as<IRGlobalValue>(inst)) { - case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_witness_table: + auto mangledName = globalValue->mangledName; + if(mangledName) { - auto irFunc = (IRFunc*) inst; - dump(context, "@"); - dump(context, getText(irFunc->mangledName).Buffer()); - } - break; - - default: - { - UInt id = getID(context, inst); - if (id) + auto mangledNameText = getText(mangledName); + if (mangledNameText.Length() > 0) { - dump(context, "%"); - dump(context, id); - } - else - { - dump(context, "_"); + dump(context, "@"); + dump(context, mangledNameText.Buffer()); + return; } } - break; + } + + UInt id = getID(context, inst); + if (id) + { + dump(context, "%"); + dump(context, id); + } + else + { + dump(context, "_"); } } @@ -1945,7 +2359,7 @@ namespace Slang // TODO: we should have a dedicated value for the `undef` case if (!inst) { - dump(context, "undef"); + dumpID(context, inst); return; } @@ -1963,16 +2377,6 @@ namespace Slang dump(context, ((IRConstant*)inst)->u.intVal ? "true" : "false"); return; - case kIROp_TypeType: - dumpType(context, (IRType*)inst); - return; - - case kIROp_decl_ref: - dump(context, "$\""); - dumpDeclRef(context, ((IRDeclRef*)inst)->declRef); - dump(context, "\""); - return; - default: break; } @@ -1980,123 +2384,6 @@ namespace Slang dumpID(context, inst); } - static void dump( - IRDumpContext* context, - Name* name) - { - dump(context, getText(name).Buffer()); - } - - static void dumpVal( - IRDumpContext* context, - Val* val) - { - if(auto type = dynamic_cast<Type*>(val)) - { - dumpType(context, type); - } - else if(auto constIntVal = dynamic_cast<ConstantIntVal*>(val)) - { - dump(context, constIntVal->value); - } - else if(auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val)) - { - dumpDeclRef(context, genericParamVal->declRef); - } - else if(auto declaredSubtypeWitness = dynamic_cast<DeclaredSubtypeWitness*>(val)) - { - dump(context, "DeclaredSubtypeWitness("); - dumpType(context, declaredSubtypeWitness->sub); - dump(context, ", "); - dumpType(context, declaredSubtypeWitness->sup); - dump(context, ", "); - dumpDeclRef(context, declaredSubtypeWitness->declRef); - dump(context, ")"); - } - else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - dumpOperand(context, proxyVal->inst.get()); - } - else - { - dump(context, "???"); - } - } - - static void dumpDeclRef( - IRDumpContext* context, - DeclRef<Decl> const& declRef) - { - auto decl = declRef.getDecl(); - - auto parentDeclRef = declRef.GetParent(); - auto genericParentDeclRef = parentDeclRef.As<GenericDecl>(); - if (genericParentDeclRef) - { - if (genericParentDeclRef.getDecl()->inner.Ptr() == decl) - { - parentDeclRef = genericParentDeclRef.GetParent(); - } - else - { - genericParentDeclRef = DeclRef<GenericDecl>(); - } - } - - if(parentDeclRef.As<ModuleDecl>()) - { - parentDeclRef = DeclRef<ContainerDecl>(); - } - else if(parentDeclRef.As<GenericDecl>()) - { - parentDeclRef = DeclRef<ContainerDecl>(); - } - - if(parentDeclRef) - { - dumpDeclRef(context, parentDeclRef); - dump(context, "."); - } - dump(context, decl->getName()); - if (auto genericTypeConstraintDecl = dynamic_cast<GenericTypeConstraintDecl*>(decl)) - { - dump(context, "{"); - dumpType(context, genericTypeConstraintDecl->sub); - dump(context, " : "); - dumpType(context, genericTypeConstraintDecl->sup); - dump(context, "}"); - } - else if (auto inheritanceDecl = dynamic_cast<InheritanceDecl*>(decl)) - { - dump(context, "{ _ : "); - dumpType(context, inheritanceDecl->base); - dump(context, "}"); - } - - if(genericParentDeclRef) - { - auto subst = declRef.substitutions.genericSubstitutions; - if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() ) - { - // No actual substitutions in place here - dump(context, "<>"); - } - else - { - auto args = subst->args; - bool first = true; - dump(context, "<"); - for(auto aa : args) - { - if(!first) dump(context, ","); - dumpVal(context, aa); - first = false; - } - dump(context, ">"); - } - } - } - static void dumpType( IRDumpContext* context, IRType* type) @@ -2107,84 +2394,10 @@ namespace Slang return; } - if(auto funcType = type->As<FuncType>()) - { - UInt paramCount = funcType->getParamCount(); - dump(context, "("); - for( UInt pp = 0; pp < paramCount; ++pp ) - { - if(pp != 0) dump(context, ", "); - dumpType(context, funcType->getParamType(pp)); - } - dump(context, ") -> "); - dumpType(context, funcType->getResultType()); - } - else if(auto arrayType = type->As<ArrayExpressionType>()) - { - dumpType(context, arrayType->baseType); - dump(context, "["); - if(auto elementCount = arrayType->ArrayLength) - { - dumpVal(context, elementCount); - } - dump(context, "]"); - } - else if(auto declRefType = type->As<DeclRefType>()) - { - dumpDeclRef(context, declRefType->declRef); - } - else if(auto groupSharedType = type->As<GroupSharedType>()) - { - dump(context, "@ThreadGroup "); - dumpType(context, groupSharedType->valueType); - } - else if(auto rateQualifiedType = type->As<RateQualifiedType>()) - { - dump(context, "@"); - dumpType(context, rateQualifiedType->rate); - dump(context, " "); - dumpType(context, rateQualifiedType->valueType); - } - else if(auto constExprRate = type->As<ConstExprRate>()) - { - dump(context, "ConstExpr"); - } - else - { - // Need a default case here - dump(context, "???"); - } - -#if 0 - auto op = type->op; - auto opInfo = kIROpInfos[op]; - - switch (op) - { - case kIROp_StructType: - dumpID(context, type); - break; - - default: - { - dump(context, opInfo.name); - UInt argCount = type->getArgCount(); - - if (argCount > 1) - { - dump(context, "<"); - for (UInt aa = 1; aa < argCount; ++aa) - { - if (aa != 1) dump(context, ","); - dumpOperand(context, type->getArg(aa)); - - } - dump(context, ">"); - } - } - break; - } -#endif + // TODO: we should consider some special-case printing + // for types, so that the IR doesn't get too hard to read + // (always having to back-reference for what a type expands to) + dumpOperand(context, type); } static void dumpInstTypeClause( @@ -2245,60 +2458,11 @@ namespace Slang } } - void dumpGenericSignature( + void dumpIRDecorations( IRDumpContext* context, - GenericDecl* genericDecl) - { - for( auto pp = genericDecl->ParentDecl; pp; pp = pp->ParentDecl ) - { - if( auto genericAncestor = dynamic_cast<GenericDecl*>(pp) ) - { - dumpGenericSignature(context, genericAncestor); - break; - } - } - - dump(context, " <"); - bool first = true; - for (auto mm : genericDecl->Members) - { - - if( auto typeParamDecl = mm.As<GenericTypeParamDecl>() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(typeParamDecl.Ptr())); - first = false; - } - else if( auto valueParamDecl = mm.As<GenericTypeParamDecl>() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(valueParamDecl.Ptr())); - first = false; - } - } - first = true; - for (auto mm : genericDecl->Members) - { - if( auto constraintDecl = mm.As<GenericTypeConstraintDecl>() ) - { - if (!first) dump(context, ", "); - else dump(context, " where "); - - dumpType(context, constraintDecl->sub); - dump(context, " : "); - dumpType(context, constraintDecl->sup); - first = false; - } - } - dump(context, ">"); - } - - void dumpIRFunc( - IRDumpContext* context, - IRFunc* func) + IRInst* inst) { - - for( auto dd = func->firstDecoration; dd; dd = dd->next ) + for( auto dd = inst->firstDecoration; dd; dd = dd->next ) { switch( dd->op ) { @@ -2316,21 +2480,26 @@ namespace Slang } } + } + + void dumpIRGlobalValueWithCode( + IRDumpContext* context, + IRGlobalValueWithCode* code) + { + // TODO: should apply this to all instructions + dumpIRDecorations(context, code); + + auto opInfo = getIROpInfo(code->op); dump(context, "\n"); dumpIndent(context); - dump(context, "ir_func "); - dumpID(context, func); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, code); - if (func->getGenericDecl()) - { - dump(context, " "); - dumpGenericSignature(context, func->getGenericDecl()); - } + dumpInstTypeClause(context, code->getFullType()); - dumpInstTypeClause(context, func->getType()); - - if (!func->getFirstBlock()) + if (!code->getFirstBlock()) { // Just a declaration. dump(context, ";\n"); @@ -2343,9 +2512,9 @@ namespace Slang dump(context, "{\n"); context->indent++; - for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + for (auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock()) { - if (bb != func->getFirstBlock()) + if (bb != code->getFirstBlock()) dump(context, "\n"); dumpBlock(context, bb); } @@ -2360,57 +2529,64 @@ namespace Slang IRDumpContext dumpContext; StringBuilder sbDump; dumpContext.builder = &sbDump; - dumpIRFunc(&dumpContext, func); + dumpIRGlobalValueWithCode(&dumpContext, func); auto strFunc = sbDump.ToString(); return strFunc; } - void dumpIRGlobalVar( + void dumpIRWitnessTableEntry( + IRDumpContext* context, + IRWitnessTableEntry* entry) + { + dump(context, "witness_table_entry("); + dumpOperand(context, entry->requirementKey.get()); + dump(context, ","); + dumpOperand(context, entry->satisfyingVal.get()); + dump(context, ")\n"); + } + + void dumpIRParentInst( IRDumpContext* context, - IRGlobalVar* var) + IRParentInst* inst) { + // TODO: should apply this to all instructions + dumpIRDecorations(context, inst); + + auto opInfo = getIROpInfo(inst->op); + dump(context, "\n"); dumpIndent(context); - dump(context, "ir_global_var "); - dumpID(context, var); - dumpInstTypeClause(context, var->getFullType()); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, inst); - // TODO: deal with the case where a global - // might have embedded initialization logic. + dumpInstTypeClause(context, inst->getFullType()); - dump(context, ";\n"); - } + if (!inst->getFirstChild()) + { + // Empty. + dump(context, ";\n"); + return; + } - void dumpIRGlobalConstant( - IRDumpContext* context, - IRGlobalConstant* val) - { dump(context, "\n"); - dumpIndent(context); - dump(context, "ir_global_constant "); - dumpID(context, val); - dumpInstTypeClause(context, val->getFullType()); - // TODO: deal with the case where a global - // might have embedded initialization logic. + dumpIndent(context); + dump(context, "{\n"); + context->indent++; - dump(context, ";\n"); - } + for (auto child = inst->getFirstChild(); child; child = child->getNextInst()) + { + dumpInst(context, child); + } - void dumpIRWitnessTableEntry( - IRDumpContext* context, - IRWitnessTableEntry* entry) - { - dump(context, "witness_table_entry("); - dumpOperand(context, entry->requirementKey.get()); - dump(context, ","); - dumpOperand(context, entry->satisfyingVal.get()); - dump(context, ")\n"); + context->indent--; + dump(context, "}\n"); } - void dumpIRWitnessTable( + void dumpIRGeneric( IRDumpContext* context, - IRWitnessTable* witnessTable) + IRGeneric* witnessTable) { dump(context, "\n"); dumpIndent(context); @@ -2447,22 +2623,18 @@ namespace Slang switch (op) { case kIROp_Func: - dumpIRFunc(context, (IRFunc*)inst); - return; - - case kIROp_global_var: - dumpIRGlobalVar(context, (IRGlobalVar*)inst); - return; - - case kIROp_global_constant: - dumpIRGlobalConstant(context, (IRGlobalConstant*)inst); + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_Generic: + dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); return; - case kIROp_witness_table: - dumpIRWitnessTable(context, (IRWitnessTable*)inst); + case kIROp_WitnessTable: + case kIROp_StructType: + dumpIRParentInst(context, (IRWitnessTable*)inst); return; - case kIROp_witness_table_entry: + case kIROp_WitnessTableEntry: dumpIRWitnessTableEntry(context, (IRWitnessTableEntry*)inst); return; @@ -2473,31 +2645,30 @@ namespace Slang // Okay, we have a seemingly "ordinary" op now dumpIndent(context); - auto opInfo = &kIROpInfos[op]; - auto type = inst->getFullType(); + auto opInfo = getIROpInfo(op); auto dataType = inst->getDataType(); + auto rate = inst->getRate(); - if (!dataType) + if(rate) { - // No result, okay... + dump(context, "@"); + dumpOperand(context, rate); + dump(context, " "); + } + + if(opHasResult(inst) || instHasUses(inst)) + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, dataType); + dump(context, "\t= "); } else { - auto basicType = dataType->As<BasicExpressionType>(); - if (basicType && basicType->baseType == BaseType::Void) - { - // No result, okay... - } - else - { - dump(context, "let "); - dumpID(context, inst); - dumpInstTypeClause(context, type); - dump(context, "\t= "); - } + // No result, okay... } - dump(context, opInfo->name); + dump(context, opInfo.name); UInt argCount = inst->getOperandCount(); UInt ii = 0; @@ -2531,7 +2702,6 @@ namespace Slang case kIROp_IntLit: case kIROp_FloatLit: case kIROp_boolConst: - case kIROp_decl_ref: dumpOperand(context, inst); break; @@ -2596,24 +2766,29 @@ namespace Slang // // - Type* IRInst::getRate() + IRRate* IRInst::getRate() { - if(auto rateQualifiedType = type->As<RateQualifiedType>()) - return rateQualifiedType->rate; + if(auto rateQualifiedType = as<IRRateQualifiedType>(getFullType())) + return rateQualifiedType->getRate(); return nullptr; } - Type* IRInst::getDataType() + IRType* IRInst::getDataType() { - if(auto rateQualifiedType = type->As<RateQualifiedType>()) - return rateQualifiedType->valueType; + auto type = getFullType(); + if(auto rateQualifiedType = as<IRRateQualifiedType>(type)) + return rateQualifiedType->getValueType(); return type; } void IRInst::replaceUsesWith(IRInst* other) { + // Safety check: don't try to replace something with itself. + if(other == this) + return; + // We will walk through the list of uses for the current // instruction, and make them point to the other inst. IRUse* ff = firstUse; @@ -2683,7 +2858,6 @@ namespace Slang void IRInst::dispose() { IRObject::dispose(); - type = decltype(type)(); } // Insert this instruction into the same basic block @@ -2862,7 +3036,7 @@ namespace Slang IRGlobalVar* addGlobalVariable( IRModule* module, - Type* valueType) + IRType* valueType) { auto session = module->session; @@ -2872,9 +3046,6 @@ namespace Slang IRBuilder builder; builder.sharedBuilder = &shared; - - RefPtr<PtrType> ptrType = session->getPtrType(valueType); - return builder.createGlobalVar(valueType); } @@ -2965,11 +3136,11 @@ namespace Slang { struct Element { + IRStructKey* key; ScalarizedVal val; - DeclRef<Decl> declRef; }; - RefPtr<Type> type; + IRType* type; List<Element> elements; }; @@ -2978,8 +3149,8 @@ namespace Slang struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl { ScalarizedVal val; - RefPtr<Type> actualType; // the actual type of `val` - RefPtr<Type> pretendType; // the type this value pretends to have + IRType* actualType; // the actual type of `val` + IRType* pretendType; // the type this value pretends to have }; struct GlobalVaryingDeclarator @@ -2990,21 +3161,21 @@ namespace Slang }; Flavor flavor; - IntVal* elementCount; + IRInst* elementCount; GlobalVaryingDeclarator* next; }; struct GLSLSystemValueInfo { // The name of the built-in GLSL variable - char const* name; + char const* name; // The name of an outer array that wraps // the variable, in the case of a GS input char const* outerArrayName; // The required type of the built-in variable - RefPtr<Type> requiredType; + IRType* requiredType; }; void requireGLSLVersionImpl( @@ -3041,6 +3212,9 @@ namespace Slang { return sink; } + + IRBuilder* builder; + IRBuilder* getBuilder() { return builder; } }; GLSLSystemValueInfo* getGLSLSystemValueInfo( @@ -3059,7 +3233,7 @@ namespace Slang auto semanticName = semanticNameSpelling.ToLower(); - RefPtr<Type> requiredType; + IRType* requiredType = nullptr; if(semanticName == "sv_position") { @@ -3190,7 +3364,7 @@ namespace Slang } name = "gl_Layer"; - requiredType = context->session->getBuiltinType(BaseType::Int); + requiredType = context->getBuilder()->getBasicType(BaseType::Int); } else if (semanticName == "sv_sampleindex") { @@ -3262,7 +3436,7 @@ namespace Slang ScalarizedVal createSimpleGLSLGlobalVarying( GLSLLegalizationContext* context, IRBuilder* builder, - Type* inType, + IRType* inType, VarLayout* inVarLayout, TypeLayout* inTypeLayout, LayoutResourceKind kind, @@ -3279,7 +3453,7 @@ namespace Slang stage, &systemValueInfoStorage); - RefPtr<Type> type = inType; + IRType* type = inType; // A system-value semantic might end up needing to override the type // that the user specified. @@ -3295,12 +3469,12 @@ namespace Slang { assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - RefPtr<ArrayExpressionType> arrayType = builder->getSession()->getArrayType( + auto arrayType = builder->getArrayType( type, dd->elementCount); RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->type = arrayType; +// arrayTypeLayout->type = arrayType; arrayTypeLayout->rules = typeLayout->rules; arrayTypeLayout->originalElementTypeLayout = typeLayout; arrayTypeLayout->elementTypeLayout = typeLayout; @@ -3355,7 +3529,7 @@ namespace Slang // the actual type of the GLSL global. auto toType = inType; - if( !fromType->Equals(toType) ) + if( fromType != toType ) { RefPtr<ScalarizedTypeAdapterValImpl> typeAdapter = new ScalarizedTypeAdapterValImpl; typeAdapter->actualType = systemValueInfo->requiredType; @@ -3381,7 +3555,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryingsImpl( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* varLayout, TypeLayout* typeLayout, LayoutResourceKind kind, @@ -3389,31 +3563,31 @@ namespace Slang UInt bindingIndex, GlobalVaryingDeclarator* declarator) { - if( type->As<BasicExpressionType>() ) + if( as<IRBasicType>(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As<VectorExpressionType>() ) + else if( as<IRVectorType>(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As<MatrixExpressionType>() ) + else if( as<IRMatrixType>(type) ) { // TODO: a matrix-type varying should probably be handled like an array of rows return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( auto arrayType = type->As<ArrayExpressionType>() ) + else if( auto arrayType = as<IRArrayType>(type) ) { // We will need to SOA-ize any nested types. - auto elementType = arrayType->baseType; - auto elementCount = arrayType->ArrayLength; + auto elementType = arrayType->getElementType(); + auto elementCount = arrayType->getElementCount(); auto arrayLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout); SLANG_ASSERT(arrayLayout); auto elementTypeLayout = arrayLayout->elementTypeLayout; @@ -3434,7 +3608,7 @@ namespace Slang bindingIndex, &arrayDeclarator); } - else if( auto streamType = type->As<HLSLStreamOutputType>() ) + else if( auto streamType = as<IRHLSLStreamOutputType>(type)) { auto elementType = streamType->getElementType(); auto streamLayout = dynamic_cast<StreamOutputTypeLayout*>(typeLayout); @@ -3452,66 +3626,60 @@ namespace Slang bindingIndex, declarator); } - else if( auto declRefType = type->As<DeclRefType>() ) + else if(auto structType = as<IRStructType>(type)) { - auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As<StructDecl>() ) - { - // This is either a user-defined struct, or a builtin type. - // TODO: exclude resource types here. + // We need to recurse down into the individual fields, + // and generate a variable for each of them. - // We need to recurse down into the individual fields, - // and generate a variable for each of them. + auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout); + SLANG_ASSERT(structTypeLayout); + RefPtr<ScalarizedTupleValImpl> tupleValImpl = new ScalarizedTupleValImpl(); - // Note: we can use the presence of a `StructTypeLayout` as - // a quick way to reject a bunch of types that aren't actually `struct`s - auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout); - if( structTypeLayout ) - { - RefPtr<ScalarizedTupleValImpl> tupleValImpl = new ScalarizedTupleValImpl(); + // Construct the actual type for the tuple (including any outer arrays) + IRType* fullType = type; + for( auto dd = declarator; dd; dd = dd->next ) + { + assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); + fullType = builder->getArrayType( + fullType, + dd->elementCount); + } - // Construct the actual type for the tuple (including any outer arrays) - RefPtr<Type> fullType = type; - for( auto dd = declarator; dd; dd = dd->next ) - { - assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - fullType = builder->getSession()->getArrayType( - fullType, - dd->elementCount); - } + tupleValImpl->type = fullType; - tupleValImpl->type = fullType; + // Okay, we want to walk through the fields here, and + // generate one variable for each. + UInt fieldCounter = 0; + for(auto field : structType->getFields()) + { + UInt fieldIndex = fieldCounter++; - // Okay, we want to walk through the fields here, and - // generate one variable for each. - for( auto ff : structTypeLayout->fields ) - { - UInt fieldBindingIndex = bindingIndex; - if(auto fieldResInfo = ff->FindResourceInfo(kind)) - fieldBindingIndex += fieldResInfo->index; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; - auto fieldVal = createGLSLGlobalVaryingsImpl( - context, - builder, - ff->typeLayout->type, - ff, - ff->typeLayout, - kind, - stage, - fieldBindingIndex, - declarator); - - ScalarizedTupleValImpl::Element element; - element.val = fieldVal; - element.declRef = ff->varDecl; - - tupleValImpl->elements.Add(element); - } + UInt fieldBindingIndex = bindingIndex; + if(auto fieldResInfo = fieldLayout->FindResourceInfo(kind)) + fieldBindingIndex += fieldResInfo->index; - return ScalarizedVal::tuple(tupleValImpl); - } + auto fieldVal = createGLSLGlobalVaryingsImpl( + context, + builder, + field->getFieldType(), + fieldLayout, + fieldLayout->typeLayout, + kind, + stage, + fieldBindingIndex, + declarator); + + ScalarizedTupleValImpl::Element element; + element.val = fieldVal; + element.key = field->getKey(); + + tupleValImpl->elements.Add(element); } + + return ScalarizedVal::tuple(tupleValImpl); } // Default case is to fall back on the simple behavior @@ -3523,7 +3691,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryings( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* layout, LayoutResourceKind kind, Stage stage) @@ -3536,27 +3704,44 @@ namespace Slang builder, type, layout, layout->typeLayout, kind, stage, bindingIndex, nullptr); } + IRType* getFieldType( + IRType* baseType, + IRStructKey* fieldKey) + { + if(auto structType = as<IRStructType>(baseType)) + { + for(auto ff : structType->getFields()) + { + if(ff->getKey() == fieldKey) + return ff->getFieldType(); + } + } + + SLANG_UNEXPECTED("no such field"); + UNREACHABLE_RETURN(nullptr); + } + ScalarizedVal extractField( IRBuilder* builder, ScalarizedVal const& val, UInt fieldIndex, - DeclRef<Decl> fieldDeclRef) + IRStructKey* fieldKey) { switch( val.flavor ) { case ScalarizedVal::Flavor::value: return ScalarizedVal::value( builder->emitFieldExtract( - GetType(fieldDeclRef.As<VarDeclBase>()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitFieldAddress( - GetType(fieldDeclRef.As<VarDeclBase>()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::tuple: { @@ -3574,8 +3759,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, IRInst* val, - Type* toType, - Type* /*fromType*/) + IRType* toType, + IRType* /*fromType*/) { // TODO: actually consider what needs to go on here... return ScalarizedVal::value(builder->emitConstructorInst( @@ -3587,8 +3772,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, ScalarizedVal const& val, - Type* toType, - Type* fromType) + IRType* toType, + IRType* fromType) { switch( val.flavor ) { @@ -3647,7 +3832,7 @@ namespace Slang builder, left, ee, - rightElement.declRef); + rightElement.key); assign(builder, leftElementVal, rightElement.val); } } @@ -3672,7 +3857,7 @@ namespace Slang builder, right, ee, - leftTupleVal->elements[ee].declRef); + leftTupleVal->elements[ee].key); assign(builder, leftTupleVal->elements[ee].val, rightElementVal); } } @@ -3699,7 +3884,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, IRInst* indexVal) { @@ -3715,7 +3900,7 @@ namespace Slang case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitElementAddress( - builder->getSession()->getPtrType(elementType), + builder->getPtrType(elementType), val.irValue, indexVal)); @@ -3729,18 +3914,10 @@ namespace Slang UInt elementCount = inputTuple->elements.Count(); UInt elementCounter = 0; - auto declRefType = dynamic_cast<DeclRefType*>(elementType); - SLANG_RELEASE_ASSERT(declRefType); - - auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDecl>(); - SLANG_RELEASE_ASSERT(aggTypeDeclRef); - - for(auto fieldDeclRef : getMembersOfType<StructField>(aggTypeDeclRef)) + auto structType = as<IRStructType>(elementType); + for(auto field : structType->getFields()) { - if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) - continue; - - auto tupleElementType = GetType(fieldDeclRef); + auto tupleElementType = field->getFieldType(); UInt elementIndex = elementCounter++; @@ -3748,7 +3925,7 @@ namespace Slang auto inputElement = inputTuple->elements[elementIndex]; ScalarizedTupleValImpl::Element resultElement; - resultElement.declRef = inputElement.declRef; + resultElement.key = inputElement.key; resultElement.val = getSubscriptVal( builder, tupleElementType, @@ -3770,7 +3947,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, UInt index) { @@ -3779,7 +3956,7 @@ namespace Slang elementType, val, builder->getIntValue( - builder->getSession()->getIntType(), + builder->getIntType(), index)); } @@ -3797,7 +3974,7 @@ namespace Slang UInt elementCount = tupleVal->elements.Count(); auto type = tupleVal->type; - if( auto arrayType = type.As<ArrayExpressionType>() ) + if( auto arrayType = as<IRArrayType>(type)) { // The tuple represent an array, which means that the // individual elements are expected to yield arrays as well. @@ -3806,13 +3983,13 @@ namespace Slang // then use these to construct our result. List<IRInst*> arrayElementVals; - UInt arrayElementCount = (UInt) GetIntVal(arrayType->ArrayLength); + UInt arrayElementCount = (UInt) GetIntVal(arrayType->getElementCount()); for( UInt ii = 0; ii < arrayElementCount; ++ii ) { auto arrayElementPseudoVal = getSubscriptVal( builder, - arrayType->baseType, + arrayType->getElementType(), val, ii); @@ -3945,6 +4122,8 @@ namespace Slang builder.sharedBuilder = &shared; builder.setInsertInto(func); + context.builder = &builder; + // We will start by looking at the return type of the // function, because that will enable us to do an // early-out check to avoid more work. @@ -3953,7 +4132,7 @@ namespace Slang // a `void` return type, because there is no work // to be done on its return value in that case. auto resultType = func->getResultType(); - if( resultType->Equals(session->getVoidType()) ) + if(as<IRVoidType>(resultType)) { // In this case, the function doesn't return a value // so we don't need to transform its `return` sites. @@ -4060,10 +4239,10 @@ namespace Slang // don't fit into the standard varying model. // For right now we are only doing special-case handling // of geometry shader output streams. - if( auto paramPtrType = paramType->As<OutTypeBase>() ) + if( auto paramPtrType = as<IROutTypeBase>(paramType) ) { auto valueType = paramPtrType->getValueType(); - if( auto gsStreamType = valueType->As<HLSLStreamOutputType>() ) + if( auto gsStreamType = as<IRHLSLStreamOutputType>(valueType) ) { // An output stream type like `TriangleStream<Foo>` should // more or less translate into `out Foo` (plus scalarization). @@ -4097,7 +4276,7 @@ namespace Slang // Is it calling the append operation? auto callee = ii->getOperand(0); - while( callee->op == kIROp_specialize ) + while( callee->op == kIROp_Specialize ) { callee = ((IRSpecialize*) callee)->getOperand(0); } @@ -4132,7 +4311,7 @@ namespace Slang // Is the parameter type a special pointer type // that indicates the parameter is used for `out` // or `inout` access? - if(auto paramPtrType = paramType->As<OutTypeBase>() ) + if(auto paramPtrType = as<IROutTypeBase>(paramType) ) { // Okay, we have the more interesting case here, // where the parameter was being passed by reference. @@ -4145,7 +4324,7 @@ namespace Slang auto localVariable = builder.emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); - if( auto inOutType = paramPtrType->As<InOutType>() ) + if( auto inOutType = as<IRInOutType>(paramPtrType) ) { // In the `in out` case we need to declare two // sets of global variables: one for the `in` @@ -4236,10 +4415,11 @@ namespace Slang // Finally, we need to patch up the type of the entry point, // because it is no longer accurate. - RefPtr<FuncType> voidFuncType = new FuncType(); - voidFuncType->setSession(session); - voidFuncType->resultType = session->getVoidType(); - func->type = voidFuncType; + IRFuncType* voidFuncType = builder.getFuncType( + 0, + nullptr, + builder.getVoidType()); + func->setFullType(voidFuncType); // TODO: we should technically be constructing // a new `EntryPointLayout` here to reflect @@ -4260,6 +4440,15 @@ namespace Slang RefPtr<IRSpecSymbol> nextWithSameName; }; + struct IRSpecEnv + { + IRSpecEnv* parent = nullptr; + + // A map from original values to their cloned equivalents. + typedef Dictionary<IRInst*, IRInst*> ClonedValueDictionary; + ClonedValueDictionary clonedValues; + }; + struct IRSharedSpecContext { // The code-generation target in use @@ -4277,16 +4466,38 @@ namespace Slang typedef Dictionary<Name*, RefPtr<IRSpecSymbol>> SymbolDictionary; SymbolDictionary symbols; - // A map from values in the original IR module - // to their equivalent in the cloned module. - typedef Dictionary<IRInst*, IRInst*> ClonedValueDictionary; - ClonedValueDictionary clonedValues; - SharedIRBuilder sharedBuilderStorage; IRBuilder builderStorage; - // Non-generic functions to be processed (for generic specialization context) - List<IRFunc*> workList; + // The "global" specialization environment. + IRSpecEnv globalEnv; + }; + + struct IRSharedGenericSpecContext : IRSharedSpecContext + { + // Instructions to be processed (for generic specialization context) + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + void addToWorkList(IRInst* inst) + { + if(!workListSet.Contains(inst)) + { + workList.Add(inst); + workListSet.Add(inst); + } + } + IRInst* popWorkList() + { + UInt count = workList.Count(); + if(count != 0) + { + IRInst* inst = workList[count - 1]; + workList.FastRemoveAt(count - 1); + workListSet.Remove(inst); + return inst; + } + return nullptr; + } }; struct IRSpecContextBase @@ -4305,13 +4516,23 @@ namespace Slang IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } - IRSharedSpecContext::ClonedValueDictionary& getClonedValues() { return getShared()->clonedValues; } + // The current specialization environment to use. + IRSpecEnv* env = nullptr; + IRSpecEnv* getEnv() + { + // TODO: need to actually establish environments on contexts we create. + // + // Or more realistically we need to change the whole approach + // to specialization and cloning so that we don't try to share + // logic between two very different cases. + + + return env; + } // The IR builder to use for creating nodes IRBuilder* builder; - SubstitutionSet subst; - // A callback to be used when a value that is not registerd in `clonedValues` // is needed during cloning. This gives the subtype a chance to intercept // the operation and clone (or not) as needed. @@ -4319,24 +4540,6 @@ namespace Slang { return originalVal; } - - // A callback used to clone (or not) types. - virtual RefPtr<Type> maybeCloneType(Type* originalType) - { - return originalType; - } - - // A callback used to clone (or not) a declaration reference - virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) - { - return declRef; - } - - // A callback used to clone (or not) a Val - virtual RefPtr<Val> maybeCloneVal(Val* val) - { - return val; - } }; void registerClonedValue( @@ -4347,19 +4550,12 @@ namespace Slang if(!originalValue) return; - // Note: setting the entry direclty here rather than - // using `Add` or `AddIfNotExists` because we can conceivably - // clone the same value (e.g., a basic block inside a generic - // function) multiple times, and that is okay, and we really - // just need to keep track of the most recent value. - - // TODO: The same thing could potentially be handled more - // cleanly by having a notion of scoping for these cloned-value - // mappings, so that we register cloned values for things - // inside of a function to a temporary mapping that we - // throw away after the function is done. - - context->getClonedValues()[originalValue] = clonedValue; + // TODO: now that things are scoped using environments, we + // shouldn't be running into the cases where a value with + // the same key already exists. This should be changed to + // an `Add()` call. + // + context->getEnv()->clonedValues[originalValue] = clonedValue; } // Information on values to use when registering a cloned value @@ -4425,6 +4621,22 @@ namespace Slang } break; + case kIRDecorationOp_Semantic: + { + auto originalDecoration = (IRSemanticDecoration*)dd; + auto newDecoration = context->builder->addDecoration<IRSemanticDecoration>(clonedValue); + newDecoration->semanticName = originalDecoration->semanticName; + } + break; + + case kIRDecorationOp_InterpolationMode: + { + auto originalDecoration = (IRInterpolationModeDecoration*)dd; + auto newDecoration = context->builder->addDecoration<IRInterpolationModeDecoration>(clonedValue); + newDecoration->mode = originalDecoration->mode; + } + break; + default: // Don't clone any decorations we don't understand. break; @@ -4435,46 +4647,37 @@ namespace Slang clonedValue->sourceLoc = originalValue->sourceLoc; } + // We use an `IRSpecContext` for the case where we are cloning + // code from one or more input modules to create a "linked" output + // module. Along the way, we will resolve profile-specific functions + // to the best definition for a given target. + // struct IRSpecContext : IRSpecContextBase { // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - // Override teh "maybe clone" logic so that we carefully - // clone any IR proxy values inside substitutions - virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override; - - virtual RefPtr<Type> maybeCloneType(Type* originalType) override; - virtual RefPtr<Val> maybeCloneVal(Val* val) override; }; IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal); - RefPtr<Substitutions> cloneSubstitutions( - IRSpecContext* context, - Substitutions* subst); - - RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As<Type>(); - } - RefPtr<Val> IRSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue); + IRType* cloneType( + IRSpecContextBase* context, + IRType* originalType); IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) { - switch (originalValue->op) + if (auto globalValue = as<IRGlobalValue>(originalValue)) { - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_Func: - case kIROp_witness_table: - return cloneGlobalValue(this, (IRGlobalValue*) originalValue); + return cloneGlobalValue(this, globalValue); + } + switch (originalValue->op) + { case kIROp_boolConst: { IRConstant* c = (IRConstant*)originalValue; @@ -4486,70 +4689,43 @@ namespace Slang case kIROp_IntLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getIntValue(c->type, c->u.intVal); + return builder->getIntValue(cloneType(this, c->getDataType()), c->u.intVal); } break; case kIROp_FloatLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getFloatValue(c->type, c->u.floatVal); + return builder->getFloatValue(cloneType(this, c->getDataType()), c->u.floatVal); } break; - case kIROp_decl_ref: + default: { - IRDeclRef* od = (IRDeclRef*)originalValue; - auto newDeclRef = od->declRef; + // In the deafult case, assume that we have some sort of "hoistable" + // instruction that requires us to create a clone of it. - // if the declRef is one of the __generic_param decl being substituted by subst - // return the substituted decl - if (subst.globalGenParamSubstitutions) + UInt argCount = originalValue->getOperandCount(); + IRInst* clonedValue = createInstWithTrailingArgs<IRInst>( + builder, + originalValue->op, + cloneType(this, originalValue->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(this, clonedValue, originalValue); + for (UInt aa = 0; aa < argCount; ++aa) { - int diff = 0; - newDeclRef = od->declRef.SubstituteImpl(subst, &diff); - for (auto globalGenSubst = subst.globalGenParamSubstitutions; globalGenSubst; globalGenSubst = globalGenSubst->outer) - { - if (!globalGenSubst) - continue; - if (newDeclRef.getDecl() == globalGenSubst->paramDecl) - return builder->getTypeVal(globalGenSubst->actualType.As<Type>()); - else if (auto genConstraint = newDeclRef.As<GenericTypeConstraintDecl>()) - { - // a decl-ref to GenericTypeConstraintDecl as a result of - // referencing a generic parameter type should be replaced with - // the actual witness table - if (genConstraint.getDecl()->ParentDecl == globalGenSubst->paramDecl) - { - // find the witness table from subst - for (auto witness : globalGenSubst->witnessTables) - { - if (witness.Key->EqualsVal(GetSup(genConstraint))) - { - auto proxyVal = witness.Value.As<IRProxyVal>(); - SLANG_ASSERT(proxyVal); - return proxyVal->inst.get(); - } - } - } - } - } + IRInst* originalArg = originalValue->getOperand(aa); + IRInst* clonedArg = cloneValue(this, originalArg); + clonedValue->getOperands()[aa].init(clonedValue, clonedArg); } - auto declRef = maybeCloneDeclRef(newDeclRef); - return builder->getDeclRefVal(declRef); - } - break; - case kIROp_TypeType: - { - IRInst* od = (IRInst*)originalValue; - int ioDiff = 0; - auto newType = od->type->SubstituteImpl(subst, &ioDiff); - return builder->getTypeVal(newType.As<Type>()); + cloneDecorations(this, clonedValue, originalValue); + + addHoistableInst(builder, clonedValue); + + return clonedValue; } break; - default: - SLANG_UNEXPECTED("no value registered for IR value"); - UNREACHABLE_RETURN(nullptr); } } @@ -4557,102 +4733,41 @@ namespace Slang IRSpecContextBase* context, IRInst* originalValue); - RefPtr<Val> cloneSubstitutionArg( - IRSpecContext* context, - Val* val) + // Find a pre-existing cloned value, or return null if none is available. + IRInst* findClonedValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - auto newIRVal = cloneValue(context, proxyVal->inst.get()); - - RefPtr<IRProxyVal> newProxyVal = new IRProxyVal(); - newProxyVal->inst.init(nullptr, newIRVal); - return newProxyVal; - } - else if (auto type = dynamic_cast<Type*>(val)) - { - return context->maybeCloneType(type); - } - else + IRInst* clonedValue = nullptr; + for (auto env = context->getEnv(); env; env = env->parent) { - return context->maybeCloneVal(val); + if (env->clonedValues.TryGetValue(originalValue, clonedValue)) + { + return clonedValue; + } } - } - RefPtr<GenericSubstitution> cloneGenericSubst(IRSpecContext* context, GenericSubstitution* genSubst) - { - if (!genSubst) - return nullptr; - - RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); - newSubst->outer = cloneGenericSubst(context, genSubst->outer); - newSubst->genericDecl = genSubst->genericDecl; - - for (auto arg : genSubst->args) - { - auto newArg = cloneSubstitutionArg(context, arg); - newSubst->args.Add(newArg); - } - return newSubst; + return nullptr; } - RefPtr<GlobalGenericParamSubstitution> cloneGlobalGenericSubst(IRSpecContext* context, GlobalGenericParamSubstitution* subst) + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (!subst) + if (!originalValue) return nullptr; - auto newSubst = new GlobalGenericParamSubstitution(); - newSubst->actualType = subst->actualType; - newSubst->paramDecl = subst->paramDecl; - newSubst->witnessTables = subst->witnessTables; - newSubst->outer = cloneGlobalGenericSubst(context, subst->outer); - return newSubst; - } - SubstitutionSet cloneSubstitutions( - IRSpecContext* context, - SubstitutionSet subst) - { - SubstitutionSet rs; - if (!subst) - return rs; - rs.genericSubstitutions = cloneGenericSubst(context, subst.genericSubstitutions); - rs.globalGenParamSubstitutions = cloneGlobalGenericSubst(context, subst.globalGenParamSubstitutions); - if (auto thisSubst = subst.thisTypeSubstitution) - { - RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution(); - newSubst->sourceType = thisSubst->sourceType; - rs.thisTypeSubstitution = newSubst; - } - return rs; - } - - DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef) - { - // Un-specialized decl? Nothing to do. - if (!declRef.substitutions) - return declRef; - - DeclRef<Decl> newDeclRef = declRef; - - // Scan through substitutions and clone as needed. - // - // TODO: this is wasteful since we clone *everything* - newDeclRef.substitutions = cloneSubstitutions(this, declRef.substitutions); + if (IRInst* clonedValue = findClonedValue(context, originalValue)) + return clonedValue; - return newDeclRef; + return context->maybeCloneValue(originalValue); } - IRInst* cloneValue( + IRType* cloneType( IRSpecContextBase* context, - IRInst* originalValue) + IRType* originalType) { - IRInst* clonedValue = nullptr; - if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) - { - return clonedValue; - } - - return context->maybeCloneValue(originalValue); + return (IRType*)cloneValue(context, originalType); } IRInst* maybeCloneValueWithMangledName( @@ -4670,50 +4785,19 @@ namespace Slang } return cloneValue(context, originalValue); } - - void cloneInst( + + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues); + + IRInst* cloneInst( IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst) + IRBuilder* builder, + IRInst* originalInst) { - switch (originalInst->op) - { - // TODO: are there any instruction types that need to be handled - // specially here? That would be anything that has more state - // than is visible in its operand list... - case 0: // nothing yet - default: - { - // The common case is that we just need to construct a cloned - // instruction with the right number of operands, intialize - // it, and then add it to the sequence. - UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = createInstWithTrailingArgs<IRInst>( - builder, originalInst->op, - context->maybeCloneType(originalInst->type), - 0, nullptr, - argCount, nullptr); - registerClonedValue(context, clonedInst, originalInst); - auto oldBuilder = context->builder; - context->builder = builder; - for (UInt aa = 0; aa < argCount; ++aa) - { - IRInst* originalArg = originalInst->getOperand(aa); - IRInst* clonedArg; - if (originalArg->op == kIROp_witness_table) - clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context, - ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg); - else - clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); - } - builder->addInst(clonedInst); - context->builder = oldBuilder; - cloneDecorations(context, clonedInst, originalInst); - } - - break; - } + return cloneInst(context, builder, originalInst, originalInst); } void cloneGlobalValueWithCodeCommon( @@ -4722,17 +4806,18 @@ namespace Slang IRGlobalValueWithCode* originalValue); IRGlobalVar* cloneGlobalVarImpl( - IRSpecContext* context, - IRGlobalVar* originalVar, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalVar* originalVar, IROriginalValuesForClone const& originalValues) { - auto clonedVar = context->builder->createGlobalVar( - context->maybeCloneType(originalVar->getDataType()->getValueType())); + auto clonedVar = builder->createGlobalVar( + cloneType(context, originalVar->getDataType()->getValueType())); if(auto rate = originalVar->getRate() ) { - clonedVar->type = context->builder->getSession()->getRateQualifiedType( - rate, clonedVar->type); + clonedVar->setFullType(builder->getRateQualifiedType( + rate, clonedVar->getFullType())); } registerClonedValue(context, clonedVar, originalValues); @@ -4745,7 +4830,7 @@ namespace Slang VarLayout* layout = nullptr; if (context->globalVarLayouts.TryGetValue(mangledName, layout)) { - context->builder->addLayoutDecoration(clonedVar, layout); + builder->addLayoutDecoration(clonedVar, layout); } // Clone any code in the body of the variable, since this @@ -4759,11 +4844,13 @@ namespace Slang } IRGlobalConstant* cloneGlobalConstantImpl( - IRSpecContext* context, - IRGlobalConstant* originalVal, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalConstant* originalVal, IROriginalValuesForClone const& originalValues) { - auto clonedVal = context->builder->createGlobalConstant(context->maybeCloneType(originalVal->getFullType())); + auto clonedVal = builder->createGlobalConstant( + cloneType(context, originalVal->getFullType())); registerClonedValue(context, clonedVal, originalValues); auto mangledName = originalVal->mangledName; @@ -4781,48 +4868,111 @@ namespace Slang return clonedVal; } - IRWitnessTable* cloneWitnessTableImpl( - IRSpecContextBase* context, - IRWitnessTable* originalTable, + IRGeneric* cloneGenericImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGeneric* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGeneric(); + registerClonedValue(context, clonedVal, originalValues); + + auto mangledName = originalVal->mangledName; + clonedVal->mangledName = mangledName; + + cloneDecorations(context, clonedVal, originalVal); + + // Clone any code in the body of the generic, since this + // computes its result value. + cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); + + return clonedVal; + } + + void cloneSimpleGlobalValueImpl( + IRSpecContextBase* context, + IRGlobalValue* originalInst, IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr, - bool registerValue = true) + IRGlobalValue* clonedInst, + bool registerValue = true) { - auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable(); if (registerValue) - registerClonedValue(context, clonedTable, originalValues); + registerClonedValue(context, clonedInst, originalValues); - auto mangledName = originalTable->mangledName; - - clonedTable->mangledName = mangledName; - clonedTable->genericDecl = originalTable->genericDecl; - clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef; - clonedTable->supTypeDeclRef = originalTable->supTypeDeclRef; - cloneDecorations(context, clonedTable, originalTable); + auto mangledName = originalInst->mangledName; + clonedInst->mangledName = mangledName; - // Clone the entries in the witness table as well - for(auto originalEntry : originalTable->getEntries() ) - { - auto clonedKey = cloneValue(context, originalEntry->requirementKey.get()); - - // if a global val with the mangled name already exists, don't clone again - auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.get())); + cloneDecorations(context, clonedInst, originalInst); - /*auto clonedEntry = */context->builder->createWitnessTableEntry( - clonedTable, - clonedKey, - clonedVal); + // Set up an IR builder for inserting into the inst + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + // Clone any children of the instruction + for (auto child : originalInst->getChildren()) + { + cloneInst(context, builder, child); } + } + IRStructKey* cloneStructKeyImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructKey* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->createStructKey(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + IRGlobalGenericParam* cloneGlobalGenericParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalGenericParam* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGlobalGenericParam(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + + IRWitnessTable* cloneWitnessTableImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues, + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) + { + auto clonedTable = dstTable ? dstTable : builder->createWitnessTable(); + cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); return clonedTable; } IRWitnessTable* cloneWitnessTableWithoutRegistering( IRSpecContextBase* context, + IRBuilder* builder, IRWitnessTable* originalTable, IRWitnessTable* dstTable = nullptr) { - return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false); + return cloneWitnessTableImpl(context, builder, originalTable, IROriginalValuesForClone(), dstTable, false); + } + + IRStructType* cloneStructTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructType* originalStruct, + IROriginalValuesForClone const& originalValues) + { + auto clonedStruct = builder->createStructType(); + cloneSimpleGlobalValueImpl(context, originalStruct, originalValues, clonedStruct); + return clonedStruct; } void cloneGlobalValueWithCodeCommon( @@ -4887,11 +5037,14 @@ namespace Slang } - void checkIRDuplicate(IRParentInst* moduleInst, Name* mangledName) + void checkIRDuplicate(IRInst* inst, IRParentInst* moduleInst, Name* mangledName) { #ifdef _DEBUG for (auto child : moduleInst->getChildren()) { + if (child == inst) + continue; + if (child->op == kIROp_Func) { auto extName = ((IRGlobalValue*)child)->mangledName; @@ -4902,6 +5055,7 @@ namespace Slang } } #else + SLANG_UNREFERENCED_PARAMETER(inst); SLANG_UNREFERENCED_PARAMETER(moduleInst); SLANG_UNREFERENCED_PARAMETER(mangledName); #endif @@ -4915,9 +5069,7 @@ namespace Slang { // First clone all the simple properties. clonedFunc->mangledName = originalFunc->mangledName; - clonedFunc->genericDecls = originalFunc->genericDecls; - clonedFunc->specializedGenericLevel = originalFunc->specializedGenericLevel; - clonedFunc->type = context->maybeCloneType(originalFunc->type); + clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); cloneDecorations(context, clonedFunc, originalFunc); @@ -4930,10 +5082,9 @@ namespace Slang // it needs to follow its dependencies. // // TODO: This isn't really a good requirement to place on the IR... - clonedFunc->removeFromParent(); + clonedFunc->moveToEnd(); if (checkDuplicate) - checkIRDuplicate(context->getModule()->getModuleInst(), clonedFunc->mangledName); - clonedFunc->insertAtEnd(context->getModule()->getModuleInst()); + checkIRDuplicate(clonedFunc, context->getModule()->getModuleInst(), clonedFunc->mangledName); } IRFunc* specializeIRForEntryPoint( @@ -5072,17 +5223,51 @@ namespace Slang return result; } + IRInst* findGenericReturnVal(IRGeneric* generic) + { + auto lastBlock = generic->getLastBlock(); + if (!lastBlock) + return nullptr; + + auto returnInst = as<IRReturnVal>(lastBlock->getTerminator()); + if (!returnInst) + return nullptr; + + auto val = returnInst->getVal(); + return val; + } + bool isDefinition( - IRGlobalValue* val) + IRGlobalValue* inVal) { + IRInst* val = inVal; + // unwrap any generic declarations to see + // the value they return. + for(;;) + { + auto genericInst = as<IRGeneric>(val); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + val = returnVal; + } + switch (val->op) { - case kIROp_witness_table: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_WitnessTable: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: case kIROp_Func: + case kIROp_Generic: return ((IRParentInst*)val)->getFirstChild() != nullptr; + case kIROp_StructType: + return true; + default: return false; } @@ -5146,51 +5331,92 @@ namespace Slang } IRFunc* cloneFuncImpl( - IRSpecContext* context, - IRFunc* originalFunc, + IRSpecContextBase* context, + IRBuilder* builder, + IRFunc* originalFunc, IROriginalValuesForClone const& originalValues) { - auto clonedFunc = context->builder->createFunc(); + auto clonedFunc = builder->createFunc(); registerClonedValue(context, clonedFunc, originalValues); cloneFunctionCommon(context, clonedFunc, originalFunc); return clonedFunc; } - // Directly clone a global value, based on a single definition/declaration, `originalVal`. - // The symbol `sym` will thread together other declarations of the same value, and - // we will register the new value as the cloned version of all of those. - IRGlobalValue* cloneGlobalValueImpl( - IRSpecContext* context, - IRGlobalValue* originalVal, - IRSpecSymbol* sym) - { - if( !originalVal ) - { - SLANG_UNEXPECTED("cloning a null value"); - UNREACHABLE_RETURN(nullptr); - } - switch( originalVal->op ) + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues) + { + switch (originalInst->op) { + // We need to special-case any instruction that is not + // allocated like an ordinary `IRInst` with trailing args. case kIROp_Func: - return cloneFuncImpl(context, (IRFunc*) originalVal, sym); + return cloneFuncImpl(context, builder, cast<IRFunc>(originalInst), originalValues); + + case kIROp_GlobalVar: + return cloneGlobalVarImpl(context, builder, cast<IRGlobalVar>(originalInst), originalValues); + + case kIROp_GlobalConstant: + return cloneGlobalConstantImpl(context, builder, cast<IRGlobalConstant>(originalInst), originalValues); + + case kIROp_WitnessTable: + return cloneWitnessTableImpl(context, builder, cast<IRWitnessTable>(originalInst), originalValues); - case kIROp_global_var: - return cloneGlobalVarImpl(context, (IRGlobalVar*)originalVal, sym); + case kIROp_StructType: + return cloneStructTypeImpl(context, builder, cast<IRStructType>(originalInst), originalValues); + + case kIROp_Generic: + return cloneGenericImpl(context, builder, cast<IRGeneric>(originalInst), originalValues); - case kIROp_global_constant: - return cloneGlobalConstantImpl(context, (IRGlobalConstant*)originalVal, sym); + case kIROp_StructKey: + return cloneStructKeyImpl(context, builder, cast<IRStructKey>(originalInst), originalValues); - case kIROp_witness_table: - return cloneWitnessTableImpl(context, (IRWitnessTable*)originalVal, sym); + case kIROp_GlobalGenericParam: + return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues); default: - SLANG_UNEXPECTED("unknown global value kind"); - UNREACHABLE_RETURN(nullptr); + break; } + // The common case is that we just need to construct a cloned + // instruction with the right number of operands, intialize + // it, and then add it to the sequence. + UInt argCount = originalInst->getOperandCount(); + IRInst* clonedInst = createInstWithTrailingArgs<IRInst>( + builder, originalInst->op, + cloneType(context, originalInst->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(context, clonedInst, originalValues); + auto oldBuilder = context->builder; + context->builder = builder; + for (UInt aa = 0; aa < argCount; ++aa) + { + IRInst* originalArg = originalInst->getOperand(aa); + IRInst* clonedArg = cloneValue(context, originalArg); + clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); + + return clonedInst; } + IRGlobalValue* cloneGlobalValueImpl( + IRSpecContext* context, + IRGlobalValue* originalInst, + IROriginalValuesForClone const& originalValues) + { + auto clonedValue = cloneInst(context, &context->shared->builderStorage, originalInst, originalValues); + clonedValue->moveToEnd(); + return cast<IRGlobalValue>(clonedValue); + } + + // Clone a global value, which has the given `mangledName`. // The `originalVal` is a known global IR value with that name, if one is available. // (It is okay for this parameter to be null). @@ -5202,7 +5428,7 @@ namespace Slang // If the global value being cloned is already in target module, don't clone // Why checking this? // When specializing a generic function G (which is already in target module), - // where G calls a normal function F (which is already in target module), + // where G calls a normal function F (which is already in target module), // then when we are making a copy of G via cloneFuncCommom(), it will recursively clone F, // however we don't want to make a duplicate of F in the target module. if (originalVal->getParent() == context->getModule()->getModuleInst()) @@ -5210,17 +5436,19 @@ namespace Slang // Check if we've already cloned this value, for the case where // an original value has already been established. - IRInst* clonedVal = nullptr; - if( originalVal && context->getClonedValues().TryGetValue(originalVal, clonedVal) ) + if (originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, originalVal)) + { + return cast<IRGlobalValue>(clonedVal); + } } if(getText(mangledName).Length() == 0) { // If there is no mangled name, then we assume this is a local symbol, // and it can't possibly have multiple declarations. - return cloneGlobalValueImpl(context, originalVal, nullptr); + return cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone()); } // @@ -5236,7 +5464,7 @@ namespace Slang // This shouldn't happen! SLANG_UNEXPECTED("no matching values registered"); - UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone())); } // We will try to track the "best" declaration we can find. @@ -5256,12 +5484,15 @@ namespace Slang // Check if we've already cloned this value, for the case where // we didn't have an original value (just a name), but we've // now found a representative value. - if( !originalVal && context->getClonedValues().TryGetValue(bestVal, clonedVal) ) + if (!originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, bestVal)) + { + return cast<IRGlobalValue>(clonedVal); + } } - return cloneGlobalValueImpl(context, bestVal, sym); + return cloneGlobalValueImpl(context, bestVal, IROriginalValuesForClone(sym)); } IRGlobalValue* cloneGlobalValueWithMangledName(IRSpecContext* context, Name* mangledName) @@ -5365,11 +5596,6 @@ namespace Slang ProgramLayout* programLayout, SubstitutionSet typeSubst); - RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context); - struct IRSpecializationState { ProgramLayout* programLayout; @@ -5382,8 +5608,16 @@ namespace Slang IRSharedSpecContext sharedContextStorage; IRSpecContext contextStorage; + IRSpecEnv globalEnv; + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } IRSpecContext* getContext() { return &contextStorage; } + + IRSpecializationState() + { + contextStorage.env = &globalEnv; + } + ~IRSpecializationState() { newProgramLayout = nullptr; @@ -5429,19 +5663,27 @@ namespace Slang auto context = state->getContext(); context->shared = sharedContext; context->builder = &sharedContext->builderStorage; - // Create the GlobalGenericParamSubstitution for substituting global generic types - // into user-provided type arguments - auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context); - context->subst.globalGenParamSubstitutions = globalParamSubst; - - // now specailize the program layout using the substitution - RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, context->subst); + // Now specialize the program layout using the substitution + // + // TODO: The specialization of the layout is conceptually an AST-level operations, + // and shouldn't be done here in the IR at all. + // + RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout( + targetReq, + programLayout, + SubstitutionSet(entryPointRequest->globalGenericSubst)); + + // TODO: we need to register the (IR-level) arguments of the global generic parameters as the + // substitutions for the generic parameters in the original IR. + + // applyGlobalGenericParamSubsitution(...); + state->newProgramLayout = newProgramLayout; // Next, we want to optimize lookup for layout infromation - // associated with global declarations, so that we can + // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) auto globalStructLayout = getGlobalStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) @@ -5453,7 +5695,7 @@ namespace Slang // for now, clone all unreferenced witness tables for (auto sym :context->getSymbols()) { - if (sym.Value->irGlobalValue->op == kIROp_witness_table) + if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } return state; @@ -5526,6 +5768,20 @@ namespace Slang // it might reference. auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout); + // HACK: right now the bindings for global generic parameters are coming in + // as part of the original IR module, and we need to make sure these get + // copied over, even if they aren't referenced. + // + for(auto inst : originalIRModule->getGlobalInsts()) + { + auto bindInst = as<IRBindGlobalGenericParam>(inst); + if(!bindInst) + continue; + + cloneValue(context, bindInst); + } + + // TODO: *technically* we should consider the case where // we have global variables with initializers, since // these should get run whether or not the entry point @@ -5551,7 +5807,7 @@ namespace Slang break; } } - + struct IRGenericSpecContext : IRSpecContextBase { IRSpecContextBase* parent = nullptr; @@ -5560,383 +5816,69 @@ namespace Slang // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - virtual RefPtr<Type> maybeCloneType(Type* originalType) override; - virtual RefPtr<Val> maybeCloneVal(Val* val) override; }; - // Convert a type-level value into an IR-level equivalent. - IRInst* getIRValue( - IRGenericSpecContext* context, - Val* val) + IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) { - if( auto subtypeWitness = dynamic_cast<SubtypeWitness*>(val) ) - { - auto mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subtypeWitness->sub, - subtypeWitness->sup)); - RefPtr<IRSpecSymbol> symbol; - - if (context->getSymbols().TryGetValue(mangledName, symbol)) - { - // Note: the symbols always come from the source module, - // not the destination module, so we may need to clone - // them if we are doing an initialize specialization pass. - return cloneValue(context, symbol->irGlobalValue); - } - else - { - // we don't have the required witness table yet, - // try to emit a specialize instruction to get one - auto subDeclRef = subtypeWitness->sub->AsDeclRefType(); - auto subDeclRefGen = DeclRef<Decl>(subDeclRef->declRef.decl, - createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl)); - - auto genericName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subDeclRefGen, - subtypeWitness->sup)); - if (context->getSymbols().TryGetValue(genericName, symbol)) - { - auto clonedSymbol = cloneValue(context, symbol->irGlobalValue); - auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, clonedSymbol, subDeclRef->declRef); - return specInst; - } - else - { - SLANG_UNEXPECTED("witness table not exist"); - UNREACHABLE_RETURN(nullptr); - } - } - } - else if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + if (parent) { - return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value); - } - else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - // The type-level value actually references an IR-level value, - // so we need to make sure to emit as if we were referencing - // the pointed-to value and not the proxy type-level `Val` - // instead. - - return context->maybeCloneValue(proxyVal->inst.get()); + return parent->maybeCloneValue(originalVal); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); + return originalVal; } } - IRInst* getSubstValue( - IRGenericSpecContext* context, - DeclRef<Decl> declRef) + // See the work list for the generic spec context with + // every relevant instruction from `inst` through its + // descendents. + void addToSpecializationWorkListRec( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) { - auto subst = context->subst.genericSubstitutions; - SLANG_ASSERT(subst); - auto genericDecl = subst->genericDecl; - - UInt orinaryParamCount = 0; - for( auto mm : genericDecl->Members ) + if(auto genericInst = as<IRGeneric>(inst)) { - if(mm.As<GenericTypeParamDecl>()) - orinaryParamCount++; - else if(mm.As<GenericValueParamDecl>()) - orinaryParamCount++; + // We do *not* consider generics, or instructions nested under them. + return; } - - if( auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>() ) + else if(auto parentInst = as<IRParentInst>(inst)) { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt constraintIndex = 0; - bool found = false; - for( auto cd : genericDecl->getMembersOfType<GenericTypeConstraintDecl>() ) - { - if( cd.Ptr() == constraintDeclRef.getDecl() ) - { - found = true; - break; - } - - constraintIndex++; - } - assert(found); + // For a parent instruction, we will scan through its contents, + // since that will be where the `specialize` instructions are - UInt argIndex = orinaryParamCount + constraintIndex; - assert(argIndex < subst->args.Count()); - - return getIRValue(context, subst->args[argIndex]); - } - else if (auto valDeclRef = declRef.As<GenericValueParamDecl>()) - { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt argIdx = 0; - bool found = false; - for (auto cd : genericDecl->Members) + for(auto child : parentInst->children) { - if (cd.Ptr() == valDeclRef.getDecl()) - { - found = true; - break; - } - if (cd.As<GenericTypeParamDecl>()) - argIdx++; - else if (cd.As<GenericValueParamDecl>()) - argIdx++; + addToSpecializationWorkListRec(sharedContext, child); } - assert(found); - - assert(argIdx < subst->args.Count()); - - return getIRValue(context, subst->args[argIdx]); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); - } - } - - IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) - { - switch( originalVal->op ) - { - case kIROp_decl_ref: - { - auto declRefVal = (IRDeclRef*) originalVal; - auto declRef = declRefVal->declRef; - auto genSubst = subst.genericSubstitutions; - SLANG_ASSERT(genSubst); - // We may have a direct reference to one of the parameters - // of the generic we are specializing, and in that case - // we nee to translate it over to the equiavalent of - // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == genSubst->genericDecl && - (declRef.As<GenericTypeParamDecl>() || declRef.As<GenericValueParamDecl>()|| - declRef.As<GenericTypeConstraintDecl>())) - { - if (auto substVal = getSubstValue(this, declRef)) - return substVal; - } - int diff = 0; - auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); - if(!diff) - return originalVal; - - return builder->getDeclRefVal(substDeclRef); - } - break; - - default: - if (parent) - { - return parent->maybeCloneValue(originalVal); - } - else - { - return originalVal; - } - } - } - - RefPtr<Type> IRGenericSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As<Type>(); - } - - RefPtr<Val> IRGenericSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } - - // Given a list of substitutions, return the inner-most - // generic substitution in the list, or NULL if there - // are no generic substitutions. - RefPtr<GenericSubstitution> getInnermostGenericSubst( - SubstitutionSet inSubst) - { - return inSubst.genericSubstitutions; - } - - RefPtr<GenericDecl> getInnermostGenericDecl( - Decl* inDecl) - { - auto decl = inDecl; - while( decl ) - { - GenericDecl* genericDecl = dynamic_cast<GenericDecl*>(decl); - if(genericDecl) - return genericDecl; - - decl = decl->ParentDecl; + // Default case: consider this instruction for specialization. + sharedContext->addToWorkList(inst); } - return nullptr; } - // This function takes a list of substitutions that we'd - // like to apply, but which might apply to a different - // declaration in cases where we have got target-specific - // overloads in the mix, and produces a new set of - // substitutiosn without this issue. - RefPtr<GenericSubstitution> cloneSubstitutionsForSpecialization( - IRSharedSpecContext* sharedContext, - RefPtr<GenericSubstitution> oldSubst, - Decl* newDecl) - { - // We will "peel back" layers of substitutions until - // we find our first generic subsitution. - auto oldGenericSubst = oldSubst; - if(!oldGenericSubst) - return nullptr; - - auto innerGenericName = oldGenericSubst->genericDecl->inner->getName(); - - // We will also peel back layers of declarations until - // we find our first generic decl. - GenericDecl* newGenericDecl = nullptr; - - for (Decl* d = newDecl; d; d = d->ParentDecl) - { - if (auto gd = dynamic_cast<GenericDecl*>(d)) - { - if (gd->inner->getName() == innerGenericName) - { - newGenericDecl = gd; - break; - } - } - } - - if( !newGenericDecl ) - { - if(auto gd = dynamic_cast<GenericDecl*>(newDecl)) - { - if( auto ed = gd->inner.As<ExtensionDecl>() ) - { - // TODO: we should confirm that it is an extension for the correct type... - - newGenericDecl = gd; - } - } - } - - SLANG_ASSERT(newGenericDecl); - - RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); - newSubst->genericDecl = newGenericDecl; - newSubst->args = oldGenericSubst->args; - - newSubst->outer = cloneSubstitutionsForSpecialization( - sharedContext, - oldGenericSubst->outer, - newGenericDecl->ParentDecl); - - return newSubst; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef<Decl> specDeclRef); - - IRWitnessTable* specializeWitnessTable( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRWitnessTable* originalTable, - DeclRef<Decl> specDeclRef, - IRWitnessTable* dstTable) + IRInst* specializeGeneric( + IRSharedGenericSpecContext* sharedContext, + IRSpecContextBase* parentContext, + IRGeneric* genericVal, + IRSpecialize* specializeInst) { // First, we want to see if an existing specialization // has already been made. To do that we will need to - // compute the mangled name of the specialized function, + // compute the mangled name of the specialized value, // so that we can look for existing declarations. - String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef), - specDeclRef.Substitute(originalTable->supTypeDeclRef)); - - if (dstTable && getText(dstTable->mangledName).Length()) - specializedMangledName = getText(dstTable->mangledName); - - // TODO: This is a terrible linear search, and we should - // avoid it by building a dictionary ahead of time, - // as is being done for the `IRSpecContext` used above. - // We can probalby use the same basic context, actually. - if (!dstTable) - { - auto module = sharedContext->module; - for(auto ii : module->getGlobalInsts()) - { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; - - if (getText(gv->mangledName) == specializedMangledName) - return (IRWitnessTable*)gv; - } - } - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - originalTable->genericDecl); - - IRGenericSpecContext context; - context.shared = sharedContext; - context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; - // TODO: other initialization is needed here... - - auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable); - - // Set up the clone to recognize that it is no longer generic - specTable->mangledName = context.getModule()->session->getNameObj(specializedMangledName); - specTable->genericDecl = nullptr; - - // Specialization of witness tables should trigger cascading specializations - // of involved functions. - for (auto entry : specTable->getEntries()) - { - if (entry->satisfyingVal.get()->op == kIROp_Func) - { - IRFunc* func = (IRFunc*)entry->satisfyingVal.get(); - auto specFunc = getSpecializedFunc(sharedContext, parentContext, func, specDeclRef); - entry->satisfyingVal.set(specFunc); - insertGlobalValueSymbol(sharedContext, specFunc); - } - - } - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specTable); - - return specTable; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef<Decl> specDeclRef) - { - // First, we want to see if an existing specialization - // has already been made. To do that we will need to - // compute the mangled name of the specialized function, - // so that we can look for existing declarations. - String specMangledName; - if (genericFunc->getGenericDecl() == specDeclRef.decl) - specMangledName = getMangledName(specDeclRef); - else - specMangledName = mangleSpecializedFuncName(getText(genericFunc->mangledName), specDeclRef.substitutions); + String specMangledName = mangleSpecializedFuncName(getText(genericVal->mangledName), specializeInst); auto specMangledNameObj = sharedContext->module->session->getNameObj(specMangledName); + + // Now look up an existing symbol with a matching name RefPtr<IRSpecSymbol> symb; if (sharedContext->symbols.TryGetValue(specMangledNameObj, symb)) { - return (IRFunc*)(symb->irGlobalValue); + return symb->irGlobalValue; } + // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. @@ -5948,104 +5890,285 @@ namespace Slang continue; if (gv->mangledName == specMangledNameObj) - return (IRFunc*) gv; + return gv; } // If we get to this point, then we need to construct a - // new `IRFunc` to represent the result of specialization. + // new IR value to represent the result of specialization. - // The substitutions we are applying might have been created - // using a different overload of a target-specific function, - // so we need to create a dummy substitution here, to make - // sure it used the correct generic. - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - genericFunc->getGenericDecl()); + // We need to establish a new mapping from inst->inst to + // handle the specialization, because we don't want the + // clones we register in this pass to cause confusion + // in later steps that might clone the same code. + + IRSpecEnv env; + env.parent = &sharedContext->globalEnv; + if (parentContext) + { + env.parent = parentContext->getEnv(); + } - if (!newSubst) - return genericFunc; + // The result of specialization should be inserted + // into the global scope, at the same location as + // the original generic. + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(genericVal); IRGenericSpecContext context; context.shared = sharedContext; context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; + context.builder = builder; + context.env = &env; - // TODO: other initialization is needed here... + // Register the arguments of the `specialize` instruction to be used + // as the "cloned" value for each of the parameters of the generic. + // + UInt argCounter = 0; + for (auto param = genericVal->getFirstParam(); param; param = param->getNextParam()) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - auto specFunc = cloneSimpleFuncWithoutRegistering(&context, genericFunc); + IRInst* arg = specializeInst->getArg(argIndex); - specFunc->mangledName = context.getModule()->session->getNameObj(specMangledName); - - // reduce specialized generic level by 1 - if (specFunc->specializedGenericLevel >= 0) - specFunc->specializedGenericLevel--; + registerClonedValue(&context, arg, param); + } - // Put the function into the global sequence right after - // the function it specializes. - // - // TODO: This shouldn't be needed, if we introduce a sorting - // step before we emit code. - //specFunc->removeFromParent(); - //specFunc->insertAfter(genericFunc); + // Okay, now we want to run through the body of the generic + // and clone stuff into the parent scope (which had + // better be the global scope). + for (auto bb : genericVal->getBlocks()) + { + // We expect a generic to only ever contain a single block. + SLANG_ASSERT(bb == genericVal->getFirstBlock()); - // At this point we've created a new non-generic function, - // which means we should add it to our work list for - // subsequent processing. - if (specFunc->specializedGenericLevel == -1) - sharedContext->workList.Add(specFunc); + for (auto ii : bb->getChildren()) + { + // Skip parameters, since they were handled earlier. + if (auto param = as<IRParam>(ii)) + continue; + + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + if (auto returnValInst = as<IRReturnVal>(ii)) + { + auto clonedResult = cloneValue(&context, returnValInst->getVal()); + if (auto clonedGlobalValue = as<IRGlobalValue>(clonedResult)) + { + clonedGlobalValue->mangledName = specMangledNameObj; + + // TODO: create a symbol for it and add it to the map. + } + + return clonedResult; + } - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specFunc); + // Otherwise, clone the instruction into the global scope + IRInst* clonedInst = cloneInst(&context, context.builder, ii); - return specFunc; + // Now that we've cloned the instruction to a location outside + // of a generic, we should consider whether it can now be specialized. + addToSpecializationWorkListRec(sharedContext, clonedInst); + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); } // Find the value in the given witness table that // satisfies the given requirement (or return // null if not found). IRInst* findWitnessVal( - IRWitnessTable* witnessTable, - DeclRef<Decl> const& requirementDeclRef) + IRWitnessTable* witnessTable, + IRInst* requirementKey) { // For now we will do a dumb linear search for( auto entry : witnessTable->getEntries() ) { - // We expect the key on the entry to be a decl-ref, - // but lets go ahead and check, just to be sure. - auto requirementKey = entry->requirementKey.get(); - if(requirementKey->op != kIROp_decl_ref) + // If the keys matched, then we use the value from this entry. + if (requirementKey == entry->requirementKey.get()) + { + auto satisfyingVal = entry->satisfyingVal.get(); + return satisfyingVal; + } + } + + // No matching entry found. + return nullptr; + } + + static bool canSpecializeGeneric( + IRGeneric* generic) + { + IRGeneric* g = generic; + for(;;) + { + auto val = findGenericReturnVal(g); + if(!val) + return false; + + if (auto nestedGeneric = as<IRGeneric>(val)) + { + // The outer generic returns an *inner* generic + // (so that multiple calls to `specialize` are + // needed to resolve it). We should look at + // what the nested generic returns to figure + // out whether specialization is allowed. + g = nestedGeneric; continue; - auto keyDeclRef = ((IRDeclRef*) requirementKey)->declRef; + } - // If the keys don't match, continue with the next entry. - if (!keyDeclRef.Equals(requirementDeclRef)) + // We've found the leaf value that will be produced after + // all of the specialization is done. Now we want to know + // if that is a value suitable for actually specializing + + if (auto globalValue = as<IRGlobalValue>(val)) { - // requirementDeclRef may be pointing to the inner decl of a generic decl - // in this case we compare keyDeclRef against the parent decl of requiredDeclRef - if (auto genRequiredDeclRef = requirementDeclRef.GetParent().As<GenericDecl>()) + if (isDefinition(globalValue)) + return true; + return false; + } + else + { + // There might be other cases with a declaration-vs-definition + // thing that we need to handle. + + return true; + } + } + } + + // Add any instruction that uses `inst` to the work list, + // so that it can be evaluated (or re-evaluated) for specialization. + void addUsesToWorkList( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + for(auto u = inst->firstUse; u; u = u->nextUse) + { + sharedContext->addToWorkList(u->getUser()); + } + } + + void specializeGenericsForInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + switch(inst->op) + { + default: + // The default behavior is to do nothing. + // An instruction is specialize-able once its operands + // are specialized, and after that it is also safe + // to consider the instruction specialized. + break; + + case kIROp_Specialize: + { + // We have a `specialize` instruction, so lets see + // whether we have an opportunity to perform the + // specialization here and now. + IRSpecialize* specInst = cast<IRSpecialize>(inst); + + // Look at the base of the `specialize`, and see if + // it directly names a generic, so that we can apply + // specialization here and now. + auto baseVal = specInst->getBase(); + if(auto genericVal = as<IRGeneric>(baseVal)) { - if (!keyDeclRef.Equals(genRequiredDeclRef)) + if (canSpecializeGeneric(genericVal)) { - continue; + // Okay, we have a candidate for specialization here. + // + // We will apply the specialization logic to the body of the generic, + // which will yield, e.g., a specialized `IRFunc`. + // + auto specializedVal = specializeGeneric(sharedContext, nullptr, genericVal, specInst); + // + // Then we will replace the use sites for the `specialize` + // instruction with uses of the specialized value. + // + addUsesToWorkList(sharedContext, specInst); + specInst->replaceUsesWith(specializedVal); + specInst->removeAndDeallocate(); } } - else - continue; } + break; + + case kIROp_lookup_interface_method: + { + // We have a `lookup_interface_method` instruction, + // so let's see whether it is a lookup in a known + // witness table. + IRLookupWitnessMethod* lookupInst = cast<IRLookupWitnessMethod>(inst); + + // We only want to deal with the case where the witness-table + // argument points to a concrete global table (and not, e.g., a + // `specialize` instruction that will yield a table) + auto witnessTable = as<IRWitnessTable>(lookupInst->witnessTable.get()); + if(!witnessTable) + break; + + // Use the witness table to look up the value that + // satisfies the requirement. + auto requirementKey = lookupInst->getRequirementKey(); + auto satisfyingVal = findWitnessVal(witnessTable, requirementKey); + // We expect to always find something, but lets just + // be careful here. + if(!satisfyingVal) + break; - // If the keys matched, then we use the value from - // this entry. - auto satisfyingVal = entry->satisfyingVal.get(); - return satisfyingVal; + // If we get through all of the above checks, then we + // have a (more) concrete method that implements the interface, + // and so we should dispatch to that directly, rather than + // use the `lookup_interface_method` instruction. + addUsesToWorkList(sharedContext, lookupInst); + lookupInst->replaceUsesWith(satisfyingVal); + lookupInst->removeAndDeallocate(); + } + break; } + } - // No matching entry found. - return nullptr; + static bool isInstSpecialized( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // If an instruction is still on our work list, then + // it isn't specialized, and conversely we say that + // if it *isn't* on the work list, it must be specialized. + // + // Note: if we end up with bugs in this logic, we could + // maintain an explicit set of specialized insts instead. + // + return !sharedContext->workListSet.Contains(inst); + } + + static bool canSpecializeInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // We can specialize an instruction once all its + // operands are specialized. + + UInt operandCount = inst->getOperandCount(); + for(UInt ii = 0; ii < operandCount; ++ii) + { + IRInst* operand = inst->getOperand(ii); + if(!isInstSpecialized(sharedContext, operand)) + return false; + } + return true; } // Go through the code in the module and try to identify @@ -6056,7 +6179,7 @@ namespace Slang IRModule* module, CodeGenTarget target) { - IRSharedSpecContext sharedContextStorage; + IRSharedGenericSpecContext sharedContextStorage; auto sharedContext = &sharedContextStorage; initializeSharedSpecContext( @@ -6066,351 +6189,127 @@ namespace Slang module, target); - // Our goal here is to find `specialize` instructions that - // can be replaced with references to a suitably sepcialized - // funciton. As a simplification, we will only consider `specialize` - // calls that are inside of non-generic functions, since we assume - // that these will allow us to fully specialize the referenced - // function. - // - // We start by building up a work list of non-generic functions. - for(auto ii : module->getGlobalInsts()) - { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; + auto moduleInst = module->getModuleInst(); - // Is it a function? If not, skip. - if(gv->op != kIROp_Func) + // First things first, let's deal with any bindings for global generic parameters. + for(auto inst : moduleInst->getChildren()) + { + auto bindInst = as<IRBindGlobalGenericParam>(inst); + if(!bindInst) continue; - auto func = (IRFunc*) gv; - // Is it generic? If so, skip. - if(func->getGenericDecl()) - continue; + auto param = bindInst->getParam(); + auto val = bindInst->getVal(); - sharedContext->workList.Add(func); + param->replaceUsesWith(val); } - - // Build dictionary for witness tables - Dictionary<Name*, IRWitnessTable*> witnessTables; - for(auto ii : module->getGlobalInsts()) { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; - - if (gv->op == kIROp_witness_table) - witnessTables.AddIfNotExists(gv->mangledName, (IRWitnessTable*)gv); - } - - // Now that we have our work list, we are going to - // process it until it goes empty. Along the way - // we may specialize a function and thus create - // a new non-generic function, and in that case - // we will add the new function to the work list. - auto& workList = sharedContext->workList; - while( auto count = workList.Count() ) - { - // We will process the last entry in the - // work list, which amounts to treating - // it like a stack when we have recursive - // specialization to perform. - auto func = workList[count-1]; - workList.RemoveAt(count-1); - - // We are going to go ahead and walk through - // all the instructions in this function, - // and look for `specialize` operations. - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + // Now we will do a second pass to clean up the + // generic parameters and their bindings. + IRInst* next = nullptr; + for(auto inst = moduleInst->getFirstChild(); inst; inst = next) { - // We need to be careful when iterating over the instructions, - // because we might end up removing the "current" instruction, - // so that accessing `ii->next` would crash. - IRInst* nextInst = nullptr; - for( auto ii = bb->getFirstInst(); ii; ii = nextInst ) - { - nextInst = ii->getNextInst(); - - // We want to handle both `specialize` instructions, - // which trigger specialization, and also `lookup_interface_method` - // instructions, which may allow us to "de-virtualize" - // calls. - - switch( ii->op ) - { - default: - // Most instructions are ones we don't care about here. - continue; - - case kIROp_specialize: - { - // We have a `specialize` instruction, so lets see - // whether we have an opportunity to perform the - // specialization here and now. - IRSpecialize* specInst = (IRSpecialize*) ii; - - // Now we extract the specialized decl-ref that will - // tell us how to specialize things. - auto specDeclRefVal = (IRDeclRef*)specInst->specDeclRefVal.get(); - auto specDeclRef = specDeclRefVal->declRef; - - // We need to specialize functions and witness tables - auto genericVal = specInst->genericVal.get(); - if (genericVal->op == kIROp_Func) - { - auto genericFunc = (IRFunc*)genericVal; - if (!genericFunc->getGenericDecl()) - continue; - - // Okay, we have a candidate for specialization here. - // - // We will first find or construct a specialized version - // of the callee funciton/ - auto specFunc = getSpecializedFunc(sharedContext, nullptr, genericFunc, specDeclRef); - // - // Then we will replace the use sites for the `specialize` - // instruction with uses of the specialized function. - // - specInst->replaceUsesWith(specFunc); - - specInst->removeAndDeallocate(); - } - else if (genericVal->op == kIROp_witness_table) - { - // specialize a witness table - auto originalTable = (IRWitnessTable*)genericVal; - auto specWitnessTable = specializeWitnessTable(sharedContext, nullptr, originalTable, specDeclRef, nullptr); - witnessTables.AddIfNotExists(specWitnessTable->mangledName, specWitnessTable); - specInst->replaceUsesWith(specWitnessTable); - specInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_witness_table: - { - // try find concrete witness table from global scope - IRLookupWitnessTable* lookupInst = (IRLookupWitnessTable*)ii; - IRWitnessTable* witnessTable = nullptr; - auto srcDeclRef = ((IRDeclRef*)lookupInst->sourceType.get())->declRef; - auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.get())->declRef; - auto mangledName = module->session->getNameObj(getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef)); - witnessTables.TryGetValue(mangledName, witnessTable); - - if (!witnessTable) - { - // try specialize the witness table - auto genDeclRef = srcDeclRef; - genDeclRef.substitutions = createDefaultSubstitutions(module->session, genDeclRef.decl); - auto genName = module->session->getNameObj(getMangledNameForConformanceWitness(genDeclRef, interfaceDeclRef)); - IRWitnessTable* genTable = nullptr; - if (witnessTables.TryGetValue(genName, genTable)) - { - witnessTable = specializeWitnessTable(sharedContext, nullptr, genTable, srcDeclRef, nullptr); - witnessTables.AddIfNotExists(witnessTable->mangledName, witnessTable); - } - } - if (witnessTable) - { - lookupInst->replaceUsesWith(witnessTable); - lookupInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_interface_method: - { - // We have a `lookup_interface_method` instruction, - // so let's see whether it is a lookup in a known - // witness table. - IRLookupWitnessMethod* lookupInst = (IRLookupWitnessMethod*) ii; - - // We only want to deal with the case where the witness-table - // argument points to a concrete global table. - auto witnessTableArg = lookupInst->witnessTable.get(); - if(witnessTableArg->op != kIROp_witness_table) - continue; - IRWitnessTable* witnessTable = (IRWitnessTable*)witnessTableArg; - - // We also need to be sure that the requirement we - // are trying to look up is identified via a decl-ref: - auto requirementArg = lookupInst->requirementDeclRef.get(); - if(requirementArg->op != kIROp_decl_ref) - continue; - auto requirementDeclRef = ((IRDeclRef*) requirementArg)->declRef; - - // Use the witness table to look up the value that - // satisfies the requirement. - auto satisfyingVal = findWitnessVal(witnessTable, requirementDeclRef); - // We expect to always find something, but lets just - // be careful here. - if(!satisfyingVal) - continue; - - // If we get through all of the above checks, then we - // have a (more) concrete method that implements the interface, - // and so we should dispatch to that directly, rather than - // use the `lookup_interface_method` instruction. - lookupInst->replaceUsesWith(satisfyingVal); - lookupInst->removeAndDeallocate(); - } - break; - } + next = inst->getNextInst(); + switch(inst->op) + { + default: + break; - // We only care about `specialize` instructions. - if(ii->op != kIROp_specialize) - continue; - + case kIROp_GlobalGenericParam: + case kIROp_BindGlobalGenericParam: + // A "bind" instruction should have no uses in the + // first place, and all the global generic parameters + // should have had their uses replaced. + SLANG_ASSERT(!inst->firstUse); + inst->removeAndDeallocate(); + break; } } } - // Once the work list has gone dry, we should have the invariant - // that there are no `specialize` instructions inside of non-generic - // functions that in turn reference a generic function. - } - - RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context) - { - RefPtr<GlobalGenericParamSubstitution> globalParamSubst; - GlobalGenericParamSubstitution * curTailSubst = nullptr; - - // Because we can't currently put `specialize` instructions inside - // witness tables, or at the global scope, we will track a set of - // witness tables that we need to clone, and then specialize - // from the original module(s) to get what we need. + // Our goal here is to find `specialize` instructions that + // can be replaced with references to, e.g., a suitably + // specialized function, and to resolve any `lookup_interface_method` + // instructions to the concrete value fetched from a witness + // table. + // + // We need to be careful of a few things: + // + // * It would not in general make sense to consider specialize-able + // instructions under an `IRGeneric`, since that could mean "specialziing" + // code to parameter values that are still unknown. + // + // * We *also* need to be careful not to specialize something when one + // or more of its inputs is also a `specialize` or `lookup_interface_method` + // instruction, because then we'd be propagating through non-concrete + // values. + // + // The approach we use here is to build a work list of instructions + // that *can* become fully specialized, but aren't yet. Any + // instruction on the work list will be considered to be "unspecialized" + // and any instruction not on the work list is considered specialized. + // + // We will start by recursively walking all the instructions to add + // the appropriate ones to our work list: + // + addToSpecializationWorkListRec(sharedContext, moduleInst); - struct WitnessTableCloneWorkItem + // Now we are going to repeatedly walk our work list, and filter + // it to create a new work list. + List<IRInst*> workListCopy; + for(;;) { - IRWitnessTable* dstTable; - IRWitnessTable* originalTable; - }; - List<WitnessTableCloneWorkItem> witnessTablesToClone; + // Swap out the work list on the context so we can + // process it here without worrying about concurrent + // modifications. + workListCopy.Clear(); + workListCopy.SwapWith(sharedContext->workList); - struct WitnessTableSpecializationWorkItem - { - IRWitnessTable* dstTable; - IRWitnessTable* srcTable; - DeclRef<Decl> specDeclRef; - }; - List<WitnessTableSpecializationWorkItem> witnessTablesToSpecailize; - - Dictionary<Name*, IRWitnessTable*> witnessTablesByName; - auto namePool = entryPointRequest->compileRequest->getNamePool(); - - for (auto param : programLayout->globalGenericParams) - { - auto paramSubst = new GlobalGenericParamSubstitution(); - if (!globalParamSubst) - globalParamSubst = paramSubst; - if (curTailSubst) - curTailSubst->outer = paramSubst; - curTailSubst = paramSubst; - paramSubst->paramDecl = param->decl; - SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count()); - paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index]; - // find witness tables - for (auto witness : entryPointRequest->genericParameterWitnesses) + if(workListCopy.Count() == 0) + break; + + for(auto inst : workListCopy) { - if (auto subtypeWitness = witness.As<SubtypeWitness>()) + // We need to check whether it is possible to specialize + // the instruction yet (it might not be because its + // operands haven't been specialized) + if(!canSpecializeInst(sharedContext, inst)) { - if (subtypeWitness->sub->EqualsVal(paramSubst->actualType)) - { - auto witnessTableName = namePool->getName(getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup)); - auto findWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - RefPtr<IRSpecSymbol> symbol; - if (!context->getSymbols().TryGetValue(name, symbol)) - return nullptr; - - return (IRWitnessTable*) symbol->irGlobalValue; - }; - - auto findCloneOfWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - IRWitnessTable* clonedTable = nullptr; - if (witnessTablesByName.TryGetValue(name, clonedTable)) - return clonedTable; - - IRWitnessTable* originalTable = findWitnessTableByName(name); - if (!originalTable) - return nullptr; - - clonedTable = context->builder->createWitnessTable(); - - WitnessTableCloneWorkItem cloneWorkItem; - cloneWorkItem.originalTable = originalTable; - cloneWorkItem.dstTable = clonedTable; - witnessTablesToClone.Add(cloneWorkItem); - - return clonedTable; - }; - - // First look for a non-generic witness table that matches - auto table = findCloneOfWitnessTableByName(witnessTableName); - if (!table) - { - // If we didn't find a non-generic table, then maybe we are looking at - // a specialization of a generic witness table. - if (auto subDeclRefType = subtypeWitness->sub.As<DeclRefType>()) - { - auto defaultSubst = createDefaultSubstitutions(entryPointRequest->compileRequest->mSession, subDeclRefType->declRef.getDecl()); - auto genericWitnessTableName = namePool->getName( - getMangledNameForConformanceWitness(DeclRef<Decl>(subDeclRefType->declRef.getDecl(), defaultSubst), subtypeWitness->sup)); - - IRWitnessTable* genericTable = findCloneOfWitnessTableByName(genericWitnessTableName); - SLANG_ASSERT(genericTable); - - WitnessTableSpecializationWorkItem specializeWorkItem; - specializeWorkItem.srcTable = genericTable; - specializeWorkItem.dstTable = context->builder->createWitnessTable(); - specializeWorkItem.dstTable->mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup)); - specializeWorkItem.specDeclRef = subDeclRefType->declRef; - - witnessTablesToSpecailize.Add(specializeWorkItem); - table = specializeWorkItem.dstTable; - } - } - // We expect to find the table no matter what. - SLANG_ASSERT(table); + // Put it back on the fresh work list, so that + // we can re-consider it in another iteration. + sharedContext->workList.Add(inst); + } + else + { + // Okay, perform any specialization step on this + // instruction that makes sense (which might be + // doing nothing). + specializeGenericsForInst(sharedContext, inst); - IRProxyVal * tableVal = new IRProxyVal(); - tableVal->inst.init(nullptr, table); - paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); - } + // Remove the instruction from consideration. + sharedContext->workListSet.Remove(inst); } } } - for (auto workItem : witnessTablesToClone) - { - cloneWitnessTableWithoutRegistering( - context, - workItem.originalTable, - workItem.dstTable); - } - - for (auto workItem : witnessTablesToSpecailize) - { - int diff = 0; - specializeWitnessTable( - context->shared, - context, - workItem.srcTable, - workItem.specDeclRef.SubstituteImpl(SubstitutionSet(nullptr, nullptr, globalParamSubst), &diff), - workItem.dstTable); - } + // Once the work list has gone dry, we should have the invariant + // that there are no `specialize` instructions inside of non-generic + // functions that in turn reference a generic function, *except* + // in the case where that generic is for a builtin function, in + // which case we wouldn't want to specialize it anyway. + } - return globalParamSubst; + void applyGlobalGenericParamSubstitution( + IRSpecContext* /*context*/) + { + // TODO: we need to figure out how to apply this } - + void markConstExpr( - Session* session, - IRInst* irValue) + IRBuilder* builder, + IRInst* irValue) { // We will take an IR value with type `T`, // and turn it into one with type `@ConstExpr T`. @@ -6418,6 +6317,9 @@ namespace Slang // TODO: need to be careful if the value already has a rate // qualifier set. - irValue->type = session->getConstExprType(irValue->getDataType()); + irValue->setFullType( + builder->getRateQualifiedType( + builder->getConstExprRate(), + irValue->getDataType())); } } |
