diff options
Diffstat (limited to 'source/slang/ir.cpp')
| -rw-r--r-- | source/slang/ir.cpp | 3358 |
1 files changed, 1630 insertions, 1728 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 75f43453a..2615c1c07 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -14,38 +14,34 @@ namespace Slang Name* mangledName, IRGlobalValue* originalVal); - - static const IROpInfo kIROpInfos[] = + struct IROpMapEntry { -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, ARG_COUNT, FLAGS, }, -#include "ir-inst-defs.h" + IROp op; + IROpInfo info; }; - // - - IROp findIROp(char const* name) + // TODO: We should ideally be speeding up the name->inst + // mapping by using a dictionary, or even by pre-computing + // a hash table to be stored as a `static const` array. + static const IROpMapEntry kIROps[] = { - // TODO: need to make this faster by using a dictionary... - - static const struct { - char const* mnemonic; - IROp op; - } kOps[] = { + { kIROp_Invalid, { "invalid", 0, 0 } }, #define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { #MNEMONIC, kIROp_##ID }, - + { kIROp_##ID, { #MNEMONIC, ARG_COUNT, FLAGS, } }, #define PSEUDO_INST(ID) \ - { #ID, kIRPseudoOp_##ID }, - + { kIRPseudoOp_##ID, { #ID, 0, 0 } }, #include "ir-inst-defs.h" - }; + }; + + // - for (auto ee : kOps) + IROp findIROp(char const* name) + { + for (auto ee : kIROps) { - if (strcmp(name, ee.mnemonic) == 0) + if (strcmp(name, ee.info.name) == 0) return ee.op; } @@ -54,7 +50,13 @@ namespace Slang IROpInfo getIROpInfo(IROp op) { - return kIROpInfos[op]; + for (auto ee : kIROps) + { + if (ee.op == op) + return ee.info; + } + + return kIROps[0].info; } // @@ -65,7 +67,6 @@ namespace Slang auto uv = this->usedValue; if(!uv) { - assert(!user); assert(!nextUse); assert(!prevLink); return; @@ -160,6 +161,22 @@ namespace Slang return nullptr; } + // IRConstant + + IRIntegerValue GetIntVal(IRInst* inst) + { + switch (inst->op) + { + default: + SLANG_UNEXPECTED("needed a known integer value"); + UNREACHABLE_RETURN(0); + + case kIROp_IntLit: + return ((IRConstant*)inst)->u.intVal; + break; + } + } + // IRParam IRParam* IRParam::getNextParam() @@ -167,6 +184,17 @@ namespace Slang return as<IRParam>(getNextInst()); } + // IRArrayTypeBase + + IRInst* IRArrayTypeBase::getElementCount() + { + if (auto arrayType = as<IRArrayType>(this)) + return arrayType->getElementCount(); + + return nullptr; + } + + // IRBlock IRParam* IRBlock::getLastParam() @@ -416,13 +444,7 @@ namespace Slang return (IRBlock*)use->get(); } - // IRFunc - - IRType* IRFunc::getResultType() { return getType()->getResultType(); } - UInt IRFunc::getParamCount() { return getType()->getParamCount(); } - IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); } - - IRParam* IRFunc::getFirstParam() + IRParam* IRGlobalValueWithParams::getFirstParam() { auto entryBlock = getFirstBlock(); if(!entryBlock) return nullptr; @@ -430,6 +452,12 @@ namespace Slang return entryBlock->getFirstParam(); } + // IRFunc + + IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } + UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } + IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } + void IRGlobalValueWithCode::addBlock(IRBlock* block) { block->insertAtEnd(this); @@ -589,7 +617,7 @@ namespace Slang { if (rr == leftNonBlock) { - SLANG_ASSERT(!parentNonBlock); + SLANG_ASSERT(!parentNonBlock || parentNonBlock == leftNonBlock); parentNonBlock = rightNonBlock; break; } @@ -677,6 +705,9 @@ namespace Slang for (UInt ii = 0; ii < operandCount; ++ii) { auto operand = inst->getOperand(ii); + if (!operand) + continue; + auto operandParent = operand->getParent(); parent = mergeCandidateParentsForHoistableInst(parent, operandParent); @@ -727,22 +758,6 @@ namespace Slang value->sourceLoc = sourceLocInfo->sourceLoc; } - template<typename T> - static T* createValue( - IRBuilder* builder, - IROp op, - IRType* type) - { - assert(builder->getModule()); - T* value = (T*)builder->getModule()->memoryPool.allocZero(sizeof(T)); - new(value)T(); - value->op = op; - value->type = type; - builder->getModule()->irObjectsToFree.Add(value); - return value; - } - - // Create an IR instruction/value and initialize it. // // In this case `argCount` and `args` represnt the @@ -752,23 +767,39 @@ namespace Slang static T* createInstImpl( IRModule* module, IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, - IRInst* const* varArgs = nullptr) + UInt varArgListCount, + UInt const* listArgCounts, + IRInst* const* const* listArgs) { + UInt varArgCount = 0; + for (UInt ii = 0; ii < varArgListCount; ++ii) + { + varArgCount += listArgCounts[ii]; + } + + UInt size = sizeof(IRInst) + (fixedArgCount + varArgCount) * sizeof(IRUse); + if (sizeof(T) > size) + { + size = sizeof(T); + } + assert(module); T* inst = (T*)module->memoryPool.allocZero(size); new(inst)T(); + inst->operandCount = (uint32_t)(fixedArgCount + varArgCount); inst->op = op; - inst->type = type; + if (type) + { + inst->typeUse.init(inst, type); + } maybeSetSourceLoc(builder, inst); @@ -783,13 +814,21 @@ namespace Slang operand++; } - for( UInt aa = 0; aa < varArgCount; ++aa ) + for (UInt ii = 0; ii < varArgListCount; ++ii) { - if (varArgs) + UInt listArgCount = listArgCounts[ii]; + for (UInt jj = 0; jj < listArgCount; ++jj) { - operand->init(inst, varArgs[aa]); + if (listArgs[ii]) + { + operand->init(inst, listArgs[ii][jj]); + } + else + { + operand->init(inst, nullptr); + } + operand++; } - operand++; } module->irObjectsToFree.Add(inst); return inst; @@ -798,24 +837,46 @@ namespace Slang template<typename T> static T* createInstImpl( IRBuilder* builder, - UInt size, IROp op, IRType* type, UInt fixedArgCount, IRInst* const* fixedArgs, - UInt varArgCount = 0, + UInt varArgCount = 0, IRInst* const* varArgs = nullptr) { return createInstImpl<T>( builder->getModule(), builder, - size, op, type, fixedArgCount, fixedArgs, - varArgCount, - varArgs); + 1, + &varArgCount, + &varArgs); + } + + template<typename T> + static T* createInstImpl( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgListCount, + UInt const* listArgCount, + IRInst* const* const* listArgs) + { + return createInstImpl<T>( + builder->getModule(), + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgListCount, + listArgCount, + listArgs); } template<typename T> @@ -828,7 +889,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, argCount, @@ -843,7 +903,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, 0, @@ -859,7 +918,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T), op, type, 1, @@ -877,7 +935,6 @@ namespace Slang IRInst* args[] = { arg1, arg2 }; return createInstImpl<T>( builder, - sizeof(T), op, type, 2, @@ -894,7 +951,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T) + argCount * sizeof(IRUse), op, type, argCount, @@ -913,7 +969,6 @@ namespace Slang { return createInstImpl<T>( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -936,7 +991,6 @@ namespace Slang return createInstImpl<T>( builder, - sizeof(T) + varArgCount * sizeof(IRUse), op, type, fixedArgCount, @@ -949,7 +1003,7 @@ namespace Slang bool operator==(IRInstKey const& left, IRInstKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->parent != right.inst->parent) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->operandCount != right.inst->operandCount) return false; auto argCount = left.inst->operandCount; @@ -967,7 +1021,7 @@ namespace Slang int IRInstKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->parent)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->getOperandCount())); auto argCount = inst->getOperandCount(); @@ -984,7 +1038,7 @@ namespace Slang bool operator==(IRConstantKey const& left, IRConstantKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->type != right.inst->type) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false; if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false; return true; @@ -993,7 +1047,7 @@ namespace Slang int IRConstantKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->type)); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0])); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1])); return code; @@ -1009,7 +1063,7 @@ namespace Slang IRConstant keyInst; memset(&keyInst, 0, sizeof(keyInst)); keyInst.op = op; - keyInst.type = type; + keyInst.typeUse.usedValue = type; memcpy(&keyInst.u, value, valueSize); IRConstantKey key; @@ -1029,7 +1083,7 @@ namespace Slang // way: we will construct a temporary instruction and // then use it to look up in a cache of instructions. - irValue = createValue<IRConstant>(builder, op, type); + irValue = createInst<IRConstant>(builder, op, type); memcpy(&irValue->u, value, valueSize); key.inst = irValue; @@ -1049,7 +1103,7 @@ namespace Slang return findOrEmitConstant( this, kIROp_boolConst, - getSession()->getBoolType(), + getBoolType(), sizeof(value), &value); } @@ -1074,72 +1128,330 @@ namespace Slang &value); } - IRUndefined* IRBuilder::emitUndefined(IRType* type) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands) { - auto inst = createInst<IRUndefined>( - this, - kIROp_undefined, - type); + UInt operandCount = 0; + for (UInt ii = 0; ii < operandListCount; ++ii) + { + operandCount += listOperandCounts[ii]; + } + + // We are going to create a dummy instruction on the stack, + // which will be used as a key for lookup, so see if we + // already have an equivalent instruction available to use. + + size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); + IRInst* keyInst = (IRInst*) malloc(keySize); + memset(keyInst, 0, keySize); + + new(keyInst) IRInst(); + keyInst->op = op; + keyInst->typeUse.usedValue = type; + keyInst->operandCount = (uint32_t) operandCount; + + IRUse* operand = keyInst->getOperands(); + for (UInt ii = 0; ii < operandListCount; ++ii) + { + UInt listOperandCount = listOperandCounts[ii]; + for (UInt jj = 0; jj < listOperandCount; ++jj) + { + operand->usedValue = listOperands[ii][jj]; + operand++; + } + } + + IRInstKey key; + key.inst = keyInst; + + IRInst* foundInst = nullptr; + bool found = builder->sharedBuilder->globalValueNumberingMap.TryGetValue(key, foundInst); + + free((void*)keyInst); + + if (found) + { + return foundInst; + } + + // If no instruction was found, then we need to emit it. + + IRInst* inst = createInstImpl<IRInst>( + builder, + op, + type, + 0, + nullptr, + operandListCount, + listOperandCounts, + listOperands); + addHoistableInst(builder, inst); + + key.inst = inst; + builder->sharedBuilder->globalValueNumberingMap.Add(key, inst); - addInst(inst); - return inst; } - IRInst* IRBuilder::getDeclRefVal( - DeclRefBase const& declRef) + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandCount, + IRInst* const* operands) + { + return findOrEmitHoistableInst( + builder, + type, + op, + 1, + &operandCount, + &operands); + } + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands) + { + UInt counts[] = { 1, operandCount }; + IRInst* const* lists[] = { &operand, operands }; + + return findOrEmitHoistableInst( + builder, + type, + op, + 2, + counts, + lists); + } + + + IRType* IRBuilder::getType( + IROp op, + UInt operandCount, + IRInst* const* operands) { - // TODO: we should cache these... - auto irValue = createValue<IRDeclRef>( + return (IRType*) findOrEmitHoistableInst( this, - kIROp_decl_ref, - nullptr); - irValue->declRef = DeclRef<Decl>(declRef.decl, declRef.substitutions); + nullptr, + op, + operandCount, + operands); + } - addHoistableInst(this, irValue); + IRType* IRBuilder::getType( + IROp op) + { + return getType(op, 0, nullptr); + } - return irValue; + IRBasicType* IRBuilder::getBasicType(BaseType baseType) + { + return (IRBasicType*)getType( + IROp((UInt)kIROp_FirstBasicType + (UInt)baseType)); + } + + IRBasicType* IRBuilder::getVoidType() + { + return (IRVoidType*)getType(kIROp_VoidType); + } + + IRBasicType* IRBuilder::getBoolType() + { + return (IRBoolType*)getType(kIROp_BoolType); + } + + IRBasicType* IRBuilder::getIntType() + { + return (IRBasicType*)getType(kIROp_IntType); + } + + IRBasicBlockType* IRBuilder::getBasicBlockType() + { + return (IRBasicBlockType*)getType(kIROp_BasicBlockType); + } + + IRTypeKind* IRBuilder::getTypeKind() + { + return (IRTypeKind*)getType(kIROp_TypeKind); + } + + IRGenericKind* IRBuilder::getGenericKind() + { + return (IRGenericKind*)getType(kIROp_GenericKind); + } + + IRPtrType* IRBuilder::getPtrType(IRType* valueType) + { + return (IRPtrType*) getPtrType(kIROp_PtrType, valueType); + } + + IROutType* IRBuilder::getOutType(IRType* valueType) + { + return (IROutType*) getPtrType(kIROp_OutType, valueType); + } + + IRInOutType* IRBuilder::getInOutType(IRType* valueType) + { + return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); + } + + IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) + { + IRInst* operands[] = { valueType }; + return (IRPtrTypeBase*) getType( + op, + 1, + operands); + } + + IRArrayTypeBase* IRBuilder::getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayTypeBase*)getType( + op, + op == kIROp_ArrayType ? 2 : 1, + operands); } - IRInst* IRBuilder::getTypeVal(IRType * type) + IRArrayType* IRBuilder::getArrayType( + IRType* elementType, + IRInst* elementCount) { - auto irValue = createValue<IRInst>( + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayType*)getType( + kIROp_ArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRUnsizedArrayType* IRBuilder::getUnsizedArrayType( + IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRUnsizedArrayType*)getType( + kIROp_UnsizedArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRVectorType* IRBuilder::getVectorType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRVectorType*)getType( + kIROp_VectorType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRMatrixType* IRBuilder::getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount) + { + IRInst* operands[] = { elementType, rowCount, columnCount }; + return (IRMatrixType*)getType( + kIROp_MatrixType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRFuncType* IRBuilder::getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType) + { + return (IRFuncType*) findOrEmitHoistableInst( this, - kIROp_TypeType, - nullptr); - irValue->type = type; - if (auto typetype = dynamic_cast<TypeType*>(type)) - irValue->type = typetype->type; - return irValue; + nullptr, + kIROp_FuncType, + resultType, + paramCount, + (IRInst* const*) paramTypes); } - IRInst* IRBuilder::emitSpecializeInst( - Type* type, - IRInst* genericVal, - IRInst* specDeclRef) + IRConstExprRate* IRBuilder::getConstExprRate() + { + return (IRConstExprRate*)getType(kIROp_ConstExprRate); + } + + IRGroupSharedRate* IRBuilder::getGroupSharedRate() + { + return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); + } + + IRRateQualifiedType* IRBuilder::getRateQualifiedType( + IRRate* rate, + IRType* dataType) + { + IRInst* operands[] = { rate, dataType }; + return (IRRateQualifiedType*)getType( + kIROp_RateQualifiedType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) + { + if (auto oldRateQualifiedType = as<IRRateQualifiedType>(inst->getFullType())) + { + // Construct a new rate-qualified type using the same rate. + + auto newRateQualifiedType = getRateQualifiedType( + oldRateQualifiedType->getRate(), + dataType); + + inst->setFullType(newRateQualifiedType); + } + else + { + // No rate? Just clobber the data type. + inst->setFullType(dataType); + } + } + + + IRUndefined* IRBuilder::emitUndefined(IRType* type) { - auto inst = createInst<IRSpecialize>( + auto inst = createInst<IRUndefined>( this, - kIROp_specialize, - type, - genericVal, - specDeclRef); + kIROp_undefined, + type); + addInst(inst); + return inst; } IRInst* IRBuilder::emitSpecializeInst( - Type* type, + IRType* type, IRInst* genericVal, - DeclRef<Decl> specDeclRef) + UInt argCount, + IRInst* const* args) { - auto specDeclRefVal = getDeclRefVal(specDeclRef); - auto inst = createInst<IRSpecialize>( + auto inst = createInstWithTrailingArgs<IRSpecialize>( this, - kIROp_specialize, + kIROp_Specialize, type, - genericVal, - specDeclRefVal); + 1, + &genericVal, + argCount, + args); + addInst(inst); return inst; } @@ -1155,45 +1467,7 @@ namespace Slang type, witnessTableVal, interfaceMethodVal); - addInst(inst); - return inst; - } - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - DeclRef<Decl> witnessTableDeclRef, - DeclRef<Decl> interfaceMethodDeclRef) - { - auto witnessTableVal = getDeclRefVal(witnessTableDeclRef); - DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - DeclRef<Decl> interfaceMethodDeclRef) - { - DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; - removeSubstDeclRef.substitutions = SubstitutionSet(); - auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); - return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); - } - - IRInst* IRBuilder::emitFindWitnessTable( - DeclRef<Decl> baseTypeDeclRef, - IRType* interfaceType) - { - auto interfaceTypeDeclRef = interfaceType->AsDeclRefType(); - SLANG_ASSERT(interfaceTypeDeclRef); - auto inst = createInst<IRLookupWitnessTable>( - this, - kIROp_lookup_witness_table, - interfaceType, - getDeclRefVal(baseTypeDeclRef), - getDeclRefVal(interfaceTypeDeclRef->declRef)); addInst(inst); return inst; } @@ -1279,10 +1553,12 @@ namespace Slang auto moduleInst = createInstImpl<IRModuleInst>( module, this, - sizeof(IRModuleInst), kIROp_Module, nullptr, 0, + nullptr, + 0, + nullptr, nullptr); module->moduleInst = moduleInst; @@ -1290,58 +1566,103 @@ namespace Slang } void addGlobalValue( - IRModule* module, + IRBuilder* builder, IRGlobalValue* value) { - if(!module) - return; + // Try to find a suitable parent for the + // global value we are emitting. + // + // We will start out search at the current + // parent instruction for the builder, and + // possibly work our way up. + // + auto parent = builder->insertIntoParent; + while(parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as<IRModuleInst>(parent)) + break; - value->insertAtEnd(module->moduleInst); + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as<IRBlock>(parent)) + { + if (as<IRGeneric>(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + // If we somehow ran out of parents (possibly + // because an instruction wasn't linked into + // the full hierarchy yet), then we will + // fall back to inserting into the overall module. + if (!parent) + { + parent = builder->getModule()->getModuleInst(); + } + + // If it turns out that we are inserting into the + // current "insert into" parent for the builder, then + // we need to respect its "insert before" setting + // as well. + if (parent == builder->insertIntoParent + && builder->insertBeforeInst) + { + value->insertBefore(builder->insertBeforeInst); + } + else + { + value->insertAtEnd(parent); + } } IRFunc* IRBuilder::createFunc() { - IRFunc* rsFunc = createValue<IRFunc>( + IRFunc* rsFunc = createInst<IRFunc>( this, kIROp_Func, nullptr); maybeSetSourceLoc(this, rsFunc); - addGlobalValue(getModule(), rsFunc); + addGlobalValue(this, rsFunc); return rsFunc; } IRGlobalVar* IRBuilder::createGlobalVar( IRType* valueType) { - auto ptrType = getSession()->getPtrType(valueType); - IRGlobalVar* globalVar = createValue<IRGlobalVar>( + auto ptrType = getPtrType(valueType); + IRGlobalVar* globalVar = createInst<IRGlobalVar>( this, - kIROp_global_var, + kIROp_GlobalVar, ptrType); maybeSetSourceLoc(this, globalVar); - addGlobalValue(getModule(), globalVar); + addGlobalValue(this, globalVar); return globalVar; } IRGlobalConstant* IRBuilder::createGlobalConstant( IRType* valueType) { - IRGlobalConstant* globalConstant = createValue<IRGlobalConstant>( + IRGlobalConstant* globalConstant = createInst<IRGlobalConstant>( this, - kIROp_global_constant, + kIROp_GlobalConstant, valueType); maybeSetSourceLoc(this, globalConstant); - addGlobalValue(getModule(), globalConstant); + addGlobalValue(this, globalConstant); return globalConstant; } IRWitnessTable* IRBuilder::createWitnessTable() { - IRWitnessTable* witnessTable = createValue<IRWitnessTable>( + IRWitnessTable* witnessTable = createInst<IRWitnessTable>( this, - kIROp_witness_table, + kIROp_WitnessTable, nullptr); - addGlobalValue(getModule(), witnessTable); + addGlobalValue(this, witnessTable); return witnessTable; } @@ -1352,7 +1673,7 @@ namespace Slang { IRWitnessTableEntry* entry = createInst<IRWitnessTableEntry>( this, - kIROp_witness_table_entry, + kIROp_WitnessTableEntry, nullptr, requirementKey, satisfyingVal); @@ -1365,6 +1686,68 @@ namespace Slang return entry; } + IRStructType* IRBuilder::createStructType() + { + IRStructType* structType = createInst<IRStructType>( + this, + kIROp_StructType, + nullptr); + addGlobalValue(this, structType); + return structType; + } + + IRStructKey* IRBuilder::createStructKey() + { + IRStructKey* structKey = createInst<IRStructKey>( + this, + kIROp_StructKey, + nullptr); + addGlobalValue(this, structKey); + return structKey; + } + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* IRBuilder::createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType) + { + IRInst* operands[] = { fieldKey, fieldType }; + IRStructField* field = (IRStructField*) createInstWithTrailingArgs<IRInst>( + this, + kIROp_StructField, + nullptr, + 0, + nullptr, + 2, + operands); + + if (structType) + { + field->insertAtEnd(structType); + } + + return field; + } + + IRGeneric* IRBuilder::createGeneric() + { + IRGeneric* irGeneric = createInst<IRGeneric>( + this, + kIROp_Generic, + nullptr); + return irGeneric; + } + + IRGeneric* IRBuilder::emitGeneric() + { + auto irGeneric = createGeneric(); + addGlobalValue(this, irGeneric); + return irGeneric; + } + + IRWitnessTable * IRBuilder::lookupWitnessTable(Name* mangledName) { IRWitnessTable * result; @@ -1381,10 +1764,10 @@ namespace Slang IRBlock* IRBuilder::createBlock() { - return createValue<IRBlock>( + return createInst<IRBlock>( this, kIROp_Block, - getSession()->getIRBasicBlockType()); + getBasicBlockType()); } IRBlock* IRBuilder::emitBlock() @@ -1409,7 +1792,7 @@ namespace Slang IRParam* IRBuilder::createParam( IRType* type) { - auto param = createValue<IRParam>( + auto param = createInst<IRParam>( this, kIROp_Param, type); @@ -1430,7 +1813,7 @@ namespace Slang IRVar* IRBuilder::emitVar( IRType* type) { - auto allocatedType = getSession()->getPtrType(type); + auto allocatedType = getPtrType(type); auto inst = createInst<IRVar>( this, kIROp_Var, @@ -1449,12 +1832,12 @@ namespace Slang // results) at the "default" rate of the parent function, // unless a subsequent analysis pass constraints it. - RefPtr<Type> valueType; - if(auto ptrType = ptr->getDataType()->As<PtrTypeBase>()) + IRType* valueType = nullptr; + if(auto ptrType = as<IRPtrTypeBase>(ptr->getDataType())) { valueType = ptrType->getValueType(); } - else if(auto ptrLikeType = ptr->getDataType()->As<PointerLikeType>()) + else if(auto ptrLikeType = as<IRPointerLikeType>(ptr->getDataType())) { valueType = ptrLikeType->getElementType(); } @@ -1465,15 +1848,20 @@ namespace Slang return nullptr; } - // Ugly special case: the result of loading from `groupshared` - // memory should not itself be `groupshared`. + // Ugly special case: if the front-end created a variable with + // type `Ptr<@R T>` instead of `@R Ptr<T>`, then the above + // logic will yield `@R T` instead of `T`, and we need to + // try and fix that up here. + // + // TODO: Lowering to the IR should be fixed to never create + // that case: rate-qualified types should only be allowed + // to appear as the type of an instruction, and should not + // be allowed as operands to type constructors (except + // in special cases we decide to allow). // - // TODO: This special case will go away once `GroupSharedType` - // is replaced by a `GroupSharedRate` that gets used together - // with `RateQualifiedType`. - if(auto rateType = valueType->As<GroupSharedType>()) + if(auto rateType = as<IRRateQualifiedType>(valueType)) { - valueType = rateType->valueType; + valueType = rateType->getValueType(); } auto inst = createInst<IRLoad>( @@ -1589,7 +1977,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1631,7 +2019,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getSession()->getBuiltinType(BaseType::Int); + auto intType = getBasicType(BaseType::Int); IRInst* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1802,6 +2190,30 @@ namespace Slang return inst; } + IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() + { + IRGlobalGenericParam* irGenericParam = createInst<IRGlobalGenericParam>( + this, + kIROp_GlobalGenericParam, + nullptr); + addGlobalValue(this, irGenericParam); + return irGenericParam; + } + + IRBindGlobalGenericParam* IRBuilder::emitBindGlobalGenericParam( + IRInst* param, + IRInst* val) + { + auto inst = createInst<IRBindGlobalGenericParam>( + this, + kIROp_BindGlobalGenericParam, + nullptr, + param, + val); + addInst(inst); + return inst; + } + IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) { auto decoration = addDecoration<IRHighLevelDeclDecoration>(inst, kIRDecorationOp_HighLevelDecl); @@ -1873,6 +2285,11 @@ namespace Slang bool opHasResult(IRInst* inst); + bool instHasUses(IRInst* inst) + { + return inst->firstUse != nullptr; + } + static UInt getID( IRDumpContext* context, IRInst* value) @@ -1881,7 +2298,7 @@ namespace Slang if (context->mapValueToID.TryGetValue(value, id)) return id; - if (opHasResult(value)) + if (opHasResult(value) || instHasUses(value)) { id = context->idCounter++; } @@ -1900,33 +2317,30 @@ namespace Slang return; } - switch(inst->op) + if (auto globalValue = as<IRGlobalValue>(inst)) { - case kIROp_Func: - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_witness_table: + auto mangledName = globalValue->mangledName; + if(mangledName) { - auto irFunc = (IRFunc*) inst; - dump(context, "@"); - dump(context, getText(irFunc->mangledName).Buffer()); - } - break; - - default: - { - UInt id = getID(context, inst); - if (id) + auto mangledNameText = getText(mangledName); + if (mangledNameText.Length() > 0) { - dump(context, "%"); - dump(context, id); - } - else - { - dump(context, "_"); + dump(context, "@"); + dump(context, mangledNameText.Buffer()); + return; } } - break; + } + + UInt id = getID(context, inst); + if (id) + { + dump(context, "%"); + dump(context, id); + } + else + { + dump(context, "_"); } } @@ -1945,7 +2359,7 @@ namespace Slang // TODO: we should have a dedicated value for the `undef` case if (!inst) { - dump(context, "undef"); + dumpID(context, inst); return; } @@ -1963,16 +2377,6 @@ namespace Slang dump(context, ((IRConstant*)inst)->u.intVal ? "true" : "false"); return; - case kIROp_TypeType: - dumpType(context, (IRType*)inst); - return; - - case kIROp_decl_ref: - dump(context, "$\""); - dumpDeclRef(context, ((IRDeclRef*)inst)->declRef); - dump(context, "\""); - return; - default: break; } @@ -1980,123 +2384,6 @@ namespace Slang dumpID(context, inst); } - static void dump( - IRDumpContext* context, - Name* name) - { - dump(context, getText(name).Buffer()); - } - - static void dumpVal( - IRDumpContext* context, - Val* val) - { - if(auto type = dynamic_cast<Type*>(val)) - { - dumpType(context, type); - } - else if(auto constIntVal = dynamic_cast<ConstantIntVal*>(val)) - { - dump(context, constIntVal->value); - } - else if(auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val)) - { - dumpDeclRef(context, genericParamVal->declRef); - } - else if(auto declaredSubtypeWitness = dynamic_cast<DeclaredSubtypeWitness*>(val)) - { - dump(context, "DeclaredSubtypeWitness("); - dumpType(context, declaredSubtypeWitness->sub); - dump(context, ", "); - dumpType(context, declaredSubtypeWitness->sup); - dump(context, ", "); - dumpDeclRef(context, declaredSubtypeWitness->declRef); - dump(context, ")"); - } - else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - dumpOperand(context, proxyVal->inst.get()); - } - else - { - dump(context, "???"); - } - } - - static void dumpDeclRef( - IRDumpContext* context, - DeclRef<Decl> const& declRef) - { - auto decl = declRef.getDecl(); - - auto parentDeclRef = declRef.GetParent(); - auto genericParentDeclRef = parentDeclRef.As<GenericDecl>(); - if (genericParentDeclRef) - { - if (genericParentDeclRef.getDecl()->inner.Ptr() == decl) - { - parentDeclRef = genericParentDeclRef.GetParent(); - } - else - { - genericParentDeclRef = DeclRef<GenericDecl>(); - } - } - - if(parentDeclRef.As<ModuleDecl>()) - { - parentDeclRef = DeclRef<ContainerDecl>(); - } - else if(parentDeclRef.As<GenericDecl>()) - { - parentDeclRef = DeclRef<ContainerDecl>(); - } - - if(parentDeclRef) - { - dumpDeclRef(context, parentDeclRef); - dump(context, "."); - } - dump(context, decl->getName()); - if (auto genericTypeConstraintDecl = dynamic_cast<GenericTypeConstraintDecl*>(decl)) - { - dump(context, "{"); - dumpType(context, genericTypeConstraintDecl->sub); - dump(context, " : "); - dumpType(context, genericTypeConstraintDecl->sup); - dump(context, "}"); - } - else if (auto inheritanceDecl = dynamic_cast<InheritanceDecl*>(decl)) - { - dump(context, "{ _ : "); - dumpType(context, inheritanceDecl->base); - dump(context, "}"); - } - - if(genericParentDeclRef) - { - auto subst = declRef.substitutions.genericSubstitutions; - if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() ) - { - // No actual substitutions in place here - dump(context, "<>"); - } - else - { - auto args = subst->args; - bool first = true; - dump(context, "<"); - for(auto aa : args) - { - if(!first) dump(context, ","); - dumpVal(context, aa); - first = false; - } - dump(context, ">"); - } - } - } - static void dumpType( IRDumpContext* context, IRType* type) @@ -2107,84 +2394,10 @@ namespace Slang return; } - if(auto funcType = type->As<FuncType>()) - { - UInt paramCount = funcType->getParamCount(); - dump(context, "("); - for( UInt pp = 0; pp < paramCount; ++pp ) - { - if(pp != 0) dump(context, ", "); - dumpType(context, funcType->getParamType(pp)); - } - dump(context, ") -> "); - dumpType(context, funcType->getResultType()); - } - else if(auto arrayType = type->As<ArrayExpressionType>()) - { - dumpType(context, arrayType->baseType); - dump(context, "["); - if(auto elementCount = arrayType->ArrayLength) - { - dumpVal(context, elementCount); - } - dump(context, "]"); - } - else if(auto declRefType = type->As<DeclRefType>()) - { - dumpDeclRef(context, declRefType->declRef); - } - else if(auto groupSharedType = type->As<GroupSharedType>()) - { - dump(context, "@ThreadGroup "); - dumpType(context, groupSharedType->valueType); - } - else if(auto rateQualifiedType = type->As<RateQualifiedType>()) - { - dump(context, "@"); - dumpType(context, rateQualifiedType->rate); - dump(context, " "); - dumpType(context, rateQualifiedType->valueType); - } - else if(auto constExprRate = type->As<ConstExprRate>()) - { - dump(context, "ConstExpr"); - } - else - { - // Need a default case here - dump(context, "???"); - } - -#if 0 - auto op = type->op; - auto opInfo = kIROpInfos[op]; - - switch (op) - { - case kIROp_StructType: - dumpID(context, type); - break; - - default: - { - dump(context, opInfo.name); - UInt argCount = type->getArgCount(); - - if (argCount > 1) - { - dump(context, "<"); - for (UInt aa = 1; aa < argCount; ++aa) - { - if (aa != 1) dump(context, ","); - dumpOperand(context, type->getArg(aa)); - - } - dump(context, ">"); - } - } - break; - } -#endif + // TODO: we should consider some special-case printing + // for types, so that the IR doesn't get too hard to read + // (always having to back-reference for what a type expands to) + dumpOperand(context, type); } static void dumpInstTypeClause( @@ -2245,60 +2458,11 @@ namespace Slang } } - void dumpGenericSignature( + void dumpIRDecorations( IRDumpContext* context, - GenericDecl* genericDecl) - { - for( auto pp = genericDecl->ParentDecl; pp; pp = pp->ParentDecl ) - { - if( auto genericAncestor = dynamic_cast<GenericDecl*>(pp) ) - { - dumpGenericSignature(context, genericAncestor); - break; - } - } - - dump(context, " <"); - bool first = true; - for (auto mm : genericDecl->Members) - { - - if( auto typeParamDecl = mm.As<GenericTypeParamDecl>() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(typeParamDecl.Ptr())); - first = false; - } - else if( auto valueParamDecl = mm.As<GenericTypeParamDecl>() ) - { - if (!first) dump(context, ", "); - dumpDeclRef(context, makeDeclRef(valueParamDecl.Ptr())); - first = false; - } - } - first = true; - for (auto mm : genericDecl->Members) - { - if( auto constraintDecl = mm.As<GenericTypeConstraintDecl>() ) - { - if (!first) dump(context, ", "); - else dump(context, " where "); - - dumpType(context, constraintDecl->sub); - dump(context, " : "); - dumpType(context, constraintDecl->sup); - first = false; - } - } - dump(context, ">"); - } - - void dumpIRFunc( - IRDumpContext* context, - IRFunc* func) + IRInst* inst) { - - for( auto dd = func->firstDecoration; dd; dd = dd->next ) + for( auto dd = inst->firstDecoration; dd; dd = dd->next ) { switch( dd->op ) { @@ -2316,21 +2480,26 @@ namespace Slang } } + } + + void dumpIRGlobalValueWithCode( + IRDumpContext* context, + IRGlobalValueWithCode* code) + { + // TODO: should apply this to all instructions + dumpIRDecorations(context, code); + + auto opInfo = getIROpInfo(code->op); dump(context, "\n"); dumpIndent(context); - dump(context, "ir_func "); - dumpID(context, func); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, code); - if (func->getGenericDecl()) - { - dump(context, " "); - dumpGenericSignature(context, func->getGenericDecl()); - } + dumpInstTypeClause(context, code->getFullType()); - dumpInstTypeClause(context, func->getType()); - - if (!func->getFirstBlock()) + if (!code->getFirstBlock()) { // Just a declaration. dump(context, ";\n"); @@ -2343,9 +2512,9 @@ namespace Slang dump(context, "{\n"); context->indent++; - for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + for (auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock()) { - if (bb != func->getFirstBlock()) + if (bb != code->getFirstBlock()) dump(context, "\n"); dumpBlock(context, bb); } @@ -2360,57 +2529,64 @@ namespace Slang IRDumpContext dumpContext; StringBuilder sbDump; dumpContext.builder = &sbDump; - dumpIRFunc(&dumpContext, func); + dumpIRGlobalValueWithCode(&dumpContext, func); auto strFunc = sbDump.ToString(); return strFunc; } - void dumpIRGlobalVar( + void dumpIRWitnessTableEntry( + IRDumpContext* context, + IRWitnessTableEntry* entry) + { + dump(context, "witness_table_entry("); + dumpOperand(context, entry->requirementKey.get()); + dump(context, ","); + dumpOperand(context, entry->satisfyingVal.get()); + dump(context, ")\n"); + } + + void dumpIRParentInst( IRDumpContext* context, - IRGlobalVar* var) + IRParentInst* inst) { + // TODO: should apply this to all instructions + dumpIRDecorations(context, inst); + + auto opInfo = getIROpInfo(inst->op); + dump(context, "\n"); dumpIndent(context); - dump(context, "ir_global_var "); - dumpID(context, var); - dumpInstTypeClause(context, var->getFullType()); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, inst); - // TODO: deal with the case where a global - // might have embedded initialization logic. + dumpInstTypeClause(context, inst->getFullType()); - dump(context, ";\n"); - } + if (!inst->getFirstChild()) + { + // Empty. + dump(context, ";\n"); + return; + } - void dumpIRGlobalConstant( - IRDumpContext* context, - IRGlobalConstant* val) - { dump(context, "\n"); - dumpIndent(context); - dump(context, "ir_global_constant "); - dumpID(context, val); - dumpInstTypeClause(context, val->getFullType()); - // TODO: deal with the case where a global - // might have embedded initialization logic. + dumpIndent(context); + dump(context, "{\n"); + context->indent++; - dump(context, ";\n"); - } + for (auto child = inst->getFirstChild(); child; child = child->getNextInst()) + { + dumpInst(context, child); + } - void dumpIRWitnessTableEntry( - IRDumpContext* context, - IRWitnessTableEntry* entry) - { - dump(context, "witness_table_entry("); - dumpOperand(context, entry->requirementKey.get()); - dump(context, ","); - dumpOperand(context, entry->satisfyingVal.get()); - dump(context, ")\n"); + context->indent--; + dump(context, "}\n"); } - void dumpIRWitnessTable( + void dumpIRGeneric( IRDumpContext* context, - IRWitnessTable* witnessTable) + IRGeneric* witnessTable) { dump(context, "\n"); dumpIndent(context); @@ -2447,22 +2623,18 @@ namespace Slang switch (op) { case kIROp_Func: - dumpIRFunc(context, (IRFunc*)inst); - return; - - case kIROp_global_var: - dumpIRGlobalVar(context, (IRGlobalVar*)inst); - return; - - case kIROp_global_constant: - dumpIRGlobalConstant(context, (IRGlobalConstant*)inst); + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_Generic: + dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); return; - case kIROp_witness_table: - dumpIRWitnessTable(context, (IRWitnessTable*)inst); + case kIROp_WitnessTable: + case kIROp_StructType: + dumpIRParentInst(context, (IRWitnessTable*)inst); return; - case kIROp_witness_table_entry: + case kIROp_WitnessTableEntry: dumpIRWitnessTableEntry(context, (IRWitnessTableEntry*)inst); return; @@ -2473,31 +2645,30 @@ namespace Slang // Okay, we have a seemingly "ordinary" op now dumpIndent(context); - auto opInfo = &kIROpInfos[op]; - auto type = inst->getFullType(); + auto opInfo = getIROpInfo(op); auto dataType = inst->getDataType(); + auto rate = inst->getRate(); - if (!dataType) + if(rate) { - // No result, okay... + dump(context, "@"); + dumpOperand(context, rate); + dump(context, " "); + } + + if(opHasResult(inst) || instHasUses(inst)) + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, dataType); + dump(context, "\t= "); } else { - auto basicType = dataType->As<BasicExpressionType>(); - if (basicType && basicType->baseType == BaseType::Void) - { - // No result, okay... - } - else - { - dump(context, "let "); - dumpID(context, inst); - dumpInstTypeClause(context, type); - dump(context, "\t= "); - } + // No result, okay... } - dump(context, opInfo->name); + dump(context, opInfo.name); UInt argCount = inst->getOperandCount(); UInt ii = 0; @@ -2531,7 +2702,6 @@ namespace Slang case kIROp_IntLit: case kIROp_FloatLit: case kIROp_boolConst: - case kIROp_decl_ref: dumpOperand(context, inst); break; @@ -2596,24 +2766,29 @@ namespace Slang // // - Type* IRInst::getRate() + IRRate* IRInst::getRate() { - if(auto rateQualifiedType = type->As<RateQualifiedType>()) - return rateQualifiedType->rate; + if(auto rateQualifiedType = as<IRRateQualifiedType>(getFullType())) + return rateQualifiedType->getRate(); return nullptr; } - Type* IRInst::getDataType() + IRType* IRInst::getDataType() { - if(auto rateQualifiedType = type->As<RateQualifiedType>()) - return rateQualifiedType->valueType; + auto type = getFullType(); + if(auto rateQualifiedType = as<IRRateQualifiedType>(type)) + return rateQualifiedType->getValueType(); return type; } void IRInst::replaceUsesWith(IRInst* other) { + // Safety check: don't try to replace something with itself. + if(other == this) + return; + // We will walk through the list of uses for the current // instruction, and make them point to the other inst. IRUse* ff = firstUse; @@ -2683,7 +2858,6 @@ namespace Slang void IRInst::dispose() { IRObject::dispose(); - type = decltype(type)(); } // Insert this instruction into the same basic block @@ -2862,7 +3036,7 @@ namespace Slang IRGlobalVar* addGlobalVariable( IRModule* module, - Type* valueType) + IRType* valueType) { auto session = module->session; @@ -2872,9 +3046,6 @@ namespace Slang IRBuilder builder; builder.sharedBuilder = &shared; - - RefPtr<PtrType> ptrType = session->getPtrType(valueType); - return builder.createGlobalVar(valueType); } @@ -2965,11 +3136,11 @@ namespace Slang { struct Element { + IRStructKey* key; ScalarizedVal val; - DeclRef<Decl> declRef; }; - RefPtr<Type> type; + IRType* type; List<Element> elements; }; @@ -2978,8 +3149,8 @@ namespace Slang struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl { ScalarizedVal val; - RefPtr<Type> actualType; // the actual type of `val` - RefPtr<Type> pretendType; // the type this value pretends to have + IRType* actualType; // the actual type of `val` + IRType* pretendType; // the type this value pretends to have }; struct GlobalVaryingDeclarator @@ -2990,21 +3161,21 @@ namespace Slang }; Flavor flavor; - IntVal* elementCount; + IRInst* elementCount; GlobalVaryingDeclarator* next; }; struct GLSLSystemValueInfo { // The name of the built-in GLSL variable - char const* name; + char const* name; // The name of an outer array that wraps // the variable, in the case of a GS input char const* outerArrayName; // The required type of the built-in variable - RefPtr<Type> requiredType; + IRType* requiredType; }; void requireGLSLVersionImpl( @@ -3041,6 +3212,9 @@ namespace Slang { return sink; } + + IRBuilder* builder; + IRBuilder* getBuilder() { return builder; } }; GLSLSystemValueInfo* getGLSLSystemValueInfo( @@ -3059,7 +3233,7 @@ namespace Slang auto semanticName = semanticNameSpelling.ToLower(); - RefPtr<Type> requiredType; + IRType* requiredType = nullptr; if(semanticName == "sv_position") { @@ -3190,7 +3364,7 @@ namespace Slang } name = "gl_Layer"; - requiredType = context->session->getBuiltinType(BaseType::Int); + requiredType = context->getBuilder()->getBasicType(BaseType::Int); } else if (semanticName == "sv_sampleindex") { @@ -3262,7 +3436,7 @@ namespace Slang ScalarizedVal createSimpleGLSLGlobalVarying( GLSLLegalizationContext* context, IRBuilder* builder, - Type* inType, + IRType* inType, VarLayout* inVarLayout, TypeLayout* inTypeLayout, LayoutResourceKind kind, @@ -3279,7 +3453,7 @@ namespace Slang stage, &systemValueInfoStorage); - RefPtr<Type> type = inType; + IRType* type = inType; // A system-value semantic might end up needing to override the type // that the user specified. @@ -3295,12 +3469,12 @@ namespace Slang { assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - RefPtr<ArrayExpressionType> arrayType = builder->getSession()->getArrayType( + auto arrayType = builder->getArrayType( type, dd->elementCount); RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->type = arrayType; +// arrayTypeLayout->type = arrayType; arrayTypeLayout->rules = typeLayout->rules; arrayTypeLayout->originalElementTypeLayout = typeLayout; arrayTypeLayout->elementTypeLayout = typeLayout; @@ -3355,7 +3529,7 @@ namespace Slang // the actual type of the GLSL global. auto toType = inType; - if( !fromType->Equals(toType) ) + if( fromType != toType ) { RefPtr<ScalarizedTypeAdapterValImpl> typeAdapter = new ScalarizedTypeAdapterValImpl; typeAdapter->actualType = systemValueInfo->requiredType; @@ -3381,7 +3555,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryingsImpl( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* varLayout, TypeLayout* typeLayout, LayoutResourceKind kind, @@ -3389,31 +3563,31 @@ namespace Slang UInt bindingIndex, GlobalVaryingDeclarator* declarator) { - if( type->As<BasicExpressionType>() ) + if( as<IRBasicType>(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As<VectorExpressionType>() ) + else if( as<IRVectorType>(type) ) { return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( type->As<MatrixExpressionType>() ) + else if( as<IRMatrixType>(type) ) { // TODO: a matrix-type varying should probably be handled like an array of rows return createSimpleGLSLGlobalVarying( context, builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); } - else if( auto arrayType = type->As<ArrayExpressionType>() ) + else if( auto arrayType = as<IRArrayType>(type) ) { // We will need to SOA-ize any nested types. - auto elementType = arrayType->baseType; - auto elementCount = arrayType->ArrayLength; + auto elementType = arrayType->getElementType(); + auto elementCount = arrayType->getElementCount(); auto arrayLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout); SLANG_ASSERT(arrayLayout); auto elementTypeLayout = arrayLayout->elementTypeLayout; @@ -3434,7 +3608,7 @@ namespace Slang bindingIndex, &arrayDeclarator); } - else if( auto streamType = type->As<HLSLStreamOutputType>() ) + else if( auto streamType = as<IRHLSLStreamOutputType>(type)) { auto elementType = streamType->getElementType(); auto streamLayout = dynamic_cast<StreamOutputTypeLayout*>(typeLayout); @@ -3452,66 +3626,60 @@ namespace Slang bindingIndex, declarator); } - else if( auto declRefType = type->As<DeclRefType>() ) + else if(auto structType = as<IRStructType>(type)) { - auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As<StructDecl>() ) - { - // This is either a user-defined struct, or a builtin type. - // TODO: exclude resource types here. + // We need to recurse down into the individual fields, + // and generate a variable for each of them. - // We need to recurse down into the individual fields, - // and generate a variable for each of them. + auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout); + SLANG_ASSERT(structTypeLayout); + RefPtr<ScalarizedTupleValImpl> tupleValImpl = new ScalarizedTupleValImpl(); - // Note: we can use the presence of a `StructTypeLayout` as - // a quick way to reject a bunch of types that aren't actually `struct`s - auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout); - if( structTypeLayout ) - { - RefPtr<ScalarizedTupleValImpl> tupleValImpl = new ScalarizedTupleValImpl(); + // Construct the actual type for the tuple (including any outer arrays) + IRType* fullType = type; + for( auto dd = declarator; dd; dd = dd->next ) + { + assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); + fullType = builder->getArrayType( + fullType, + dd->elementCount); + } - // Construct the actual type for the tuple (including any outer arrays) - RefPtr<Type> fullType = type; - for( auto dd = declarator; dd; dd = dd->next ) - { - assert(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - fullType = builder->getSession()->getArrayType( - fullType, - dd->elementCount); - } + tupleValImpl->type = fullType; - tupleValImpl->type = fullType; + // Okay, we want to walk through the fields here, and + // generate one variable for each. + UInt fieldCounter = 0; + for(auto field : structType->getFields()) + { + UInt fieldIndex = fieldCounter++; - // Okay, we want to walk through the fields here, and - // generate one variable for each. - for( auto ff : structTypeLayout->fields ) - { - UInt fieldBindingIndex = bindingIndex; - if(auto fieldResInfo = ff->FindResourceInfo(kind)) - fieldBindingIndex += fieldResInfo->index; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; - auto fieldVal = createGLSLGlobalVaryingsImpl( - context, - builder, - ff->typeLayout->type, - ff, - ff->typeLayout, - kind, - stage, - fieldBindingIndex, - declarator); - - ScalarizedTupleValImpl::Element element; - element.val = fieldVal; - element.declRef = ff->varDecl; - - tupleValImpl->elements.Add(element); - } + UInt fieldBindingIndex = bindingIndex; + if(auto fieldResInfo = fieldLayout->FindResourceInfo(kind)) + fieldBindingIndex += fieldResInfo->index; - return ScalarizedVal::tuple(tupleValImpl); - } + auto fieldVal = createGLSLGlobalVaryingsImpl( + context, + builder, + field->getFieldType(), + fieldLayout, + fieldLayout->typeLayout, + kind, + stage, + fieldBindingIndex, + declarator); + + ScalarizedTupleValImpl::Element element; + element.val = fieldVal; + element.key = field->getKey(); + + tupleValImpl->elements.Add(element); } + + return ScalarizedVal::tuple(tupleValImpl); } // Default case is to fall back on the simple behavior @@ -3523,7 +3691,7 @@ namespace Slang ScalarizedVal createGLSLGlobalVaryings( GLSLLegalizationContext* context, IRBuilder* builder, - Type* type, + IRType* type, VarLayout* layout, LayoutResourceKind kind, Stage stage) @@ -3536,27 +3704,44 @@ namespace Slang builder, type, layout, layout->typeLayout, kind, stage, bindingIndex, nullptr); } + IRType* getFieldType( + IRType* baseType, + IRStructKey* fieldKey) + { + if(auto structType = as<IRStructType>(baseType)) + { + for(auto ff : structType->getFields()) + { + if(ff->getKey() == fieldKey) + return ff->getFieldType(); + } + } + + SLANG_UNEXPECTED("no such field"); + UNREACHABLE_RETURN(nullptr); + } + ScalarizedVal extractField( IRBuilder* builder, ScalarizedVal const& val, UInt fieldIndex, - DeclRef<Decl> fieldDeclRef) + IRStructKey* fieldKey) { switch( val.flavor ) { case ScalarizedVal::Flavor::value: return ScalarizedVal::value( builder->emitFieldExtract( - GetType(fieldDeclRef.As<VarDeclBase>()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitFieldAddress( - GetType(fieldDeclRef.As<VarDeclBase>()), + getFieldType(val.irValue->getDataType(), fieldKey), val.irValue, - builder->getDeclRefVal(fieldDeclRef))); + fieldKey)); case ScalarizedVal::Flavor::tuple: { @@ -3574,8 +3759,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, IRInst* val, - Type* toType, - Type* /*fromType*/) + IRType* toType, + IRType* /*fromType*/) { // TODO: actually consider what needs to go on here... return ScalarizedVal::value(builder->emitConstructorInst( @@ -3587,8 +3772,8 @@ namespace Slang ScalarizedVal adaptType( IRBuilder* builder, ScalarizedVal const& val, - Type* toType, - Type* fromType) + IRType* toType, + IRType* fromType) { switch( val.flavor ) { @@ -3647,7 +3832,7 @@ namespace Slang builder, left, ee, - rightElement.declRef); + rightElement.key); assign(builder, leftElementVal, rightElement.val); } } @@ -3672,7 +3857,7 @@ namespace Slang builder, right, ee, - leftTupleVal->elements[ee].declRef); + leftTupleVal->elements[ee].key); assign(builder, leftTupleVal->elements[ee].val, rightElementVal); } } @@ -3699,7 +3884,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, IRInst* indexVal) { @@ -3715,7 +3900,7 @@ namespace Slang case ScalarizedVal::Flavor::address: return ScalarizedVal::address( builder->emitElementAddress( - builder->getSession()->getPtrType(elementType), + builder->getPtrType(elementType), val.irValue, indexVal)); @@ -3729,18 +3914,10 @@ namespace Slang UInt elementCount = inputTuple->elements.Count(); UInt elementCounter = 0; - auto declRefType = dynamic_cast<DeclRefType*>(elementType); - SLANG_RELEASE_ASSERT(declRefType); - - auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDecl>(); - SLANG_RELEASE_ASSERT(aggTypeDeclRef); - - for(auto fieldDeclRef : getMembersOfType<StructField>(aggTypeDeclRef)) + auto structType = as<IRStructType>(elementType); + for(auto field : structType->getFields()) { - if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) - continue; - - auto tupleElementType = GetType(fieldDeclRef); + auto tupleElementType = field->getFieldType(); UInt elementIndex = elementCounter++; @@ -3748,7 +3925,7 @@ namespace Slang auto inputElement = inputTuple->elements[elementIndex]; ScalarizedTupleValImpl::Element resultElement; - resultElement.declRef = inputElement.declRef; + resultElement.key = inputElement.key; resultElement.val = getSubscriptVal( builder, tupleElementType, @@ -3770,7 +3947,7 @@ namespace Slang ScalarizedVal getSubscriptVal( IRBuilder* builder, - Type* elementType, + IRType* elementType, ScalarizedVal val, UInt index) { @@ -3779,7 +3956,7 @@ namespace Slang elementType, val, builder->getIntValue( - builder->getSession()->getIntType(), + builder->getIntType(), index)); } @@ -3797,7 +3974,7 @@ namespace Slang UInt elementCount = tupleVal->elements.Count(); auto type = tupleVal->type; - if( auto arrayType = type.As<ArrayExpressionType>() ) + if( auto arrayType = as<IRArrayType>(type)) { // The tuple represent an array, which means that the // individual elements are expected to yield arrays as well. @@ -3806,13 +3983,13 @@ namespace Slang // then use these to construct our result. List<IRInst*> arrayElementVals; - UInt arrayElementCount = (UInt) GetIntVal(arrayType->ArrayLength); + UInt arrayElementCount = (UInt) GetIntVal(arrayType->getElementCount()); for( UInt ii = 0; ii < arrayElementCount; ++ii ) { auto arrayElementPseudoVal = getSubscriptVal( builder, - arrayType->baseType, + arrayType->getElementType(), val, ii); @@ -3945,6 +4122,8 @@ namespace Slang builder.sharedBuilder = &shared; builder.setInsertInto(func); + context.builder = &builder; + // We will start by looking at the return type of the // function, because that will enable us to do an // early-out check to avoid more work. @@ -3953,7 +4132,7 @@ namespace Slang // a `void` return type, because there is no work // to be done on its return value in that case. auto resultType = func->getResultType(); - if( resultType->Equals(session->getVoidType()) ) + if(as<IRVoidType>(resultType)) { // In this case, the function doesn't return a value // so we don't need to transform its `return` sites. @@ -4060,10 +4239,10 @@ namespace Slang // don't fit into the standard varying model. // For right now we are only doing special-case handling // of geometry shader output streams. - if( auto paramPtrType = paramType->As<OutTypeBase>() ) + if( auto paramPtrType = as<IROutTypeBase>(paramType) ) { auto valueType = paramPtrType->getValueType(); - if( auto gsStreamType = valueType->As<HLSLStreamOutputType>() ) + if( auto gsStreamType = as<IRHLSLStreamOutputType>(valueType) ) { // An output stream type like `TriangleStream<Foo>` should // more or less translate into `out Foo` (plus scalarization). @@ -4097,7 +4276,7 @@ namespace Slang // Is it calling the append operation? auto callee = ii->getOperand(0); - while( callee->op == kIROp_specialize ) + while( callee->op == kIROp_Specialize ) { callee = ((IRSpecialize*) callee)->getOperand(0); } @@ -4132,7 +4311,7 @@ namespace Slang // Is the parameter type a special pointer type // that indicates the parameter is used for `out` // or `inout` access? - if(auto paramPtrType = paramType->As<OutTypeBase>() ) + if(auto paramPtrType = as<IROutTypeBase>(paramType) ) { // Okay, we have the more interesting case here, // where the parameter was being passed by reference. @@ -4145,7 +4324,7 @@ namespace Slang auto localVariable = builder.emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); - if( auto inOutType = paramPtrType->As<InOutType>() ) + if( auto inOutType = as<IRInOutType>(paramPtrType) ) { // In the `in out` case we need to declare two // sets of global variables: one for the `in` @@ -4236,10 +4415,11 @@ namespace Slang // Finally, we need to patch up the type of the entry point, // because it is no longer accurate. - RefPtr<FuncType> voidFuncType = new FuncType(); - voidFuncType->setSession(session); - voidFuncType->resultType = session->getVoidType(); - func->type = voidFuncType; + IRFuncType* voidFuncType = builder.getFuncType( + 0, + nullptr, + builder.getVoidType()); + func->setFullType(voidFuncType); // TODO: we should technically be constructing // a new `EntryPointLayout` here to reflect @@ -4260,6 +4440,15 @@ namespace Slang RefPtr<IRSpecSymbol> nextWithSameName; }; + struct IRSpecEnv + { + IRSpecEnv* parent = nullptr; + + // A map from original values to their cloned equivalents. + typedef Dictionary<IRInst*, IRInst*> ClonedValueDictionary; + ClonedValueDictionary clonedValues; + }; + struct IRSharedSpecContext { // The code-generation target in use @@ -4277,16 +4466,38 @@ namespace Slang typedef Dictionary<Name*, RefPtr<IRSpecSymbol>> SymbolDictionary; SymbolDictionary symbols; - // A map from values in the original IR module - // to their equivalent in the cloned module. - typedef Dictionary<IRInst*, IRInst*> ClonedValueDictionary; - ClonedValueDictionary clonedValues; - SharedIRBuilder sharedBuilderStorage; IRBuilder builderStorage; - // Non-generic functions to be processed (for generic specialization context) - List<IRFunc*> workList; + // The "global" specialization environment. + IRSpecEnv globalEnv; + }; + + struct IRSharedGenericSpecContext : IRSharedSpecContext + { + // Instructions to be processed (for generic specialization context) + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + void addToWorkList(IRInst* inst) + { + if(!workListSet.Contains(inst)) + { + workList.Add(inst); + workListSet.Add(inst); + } + } + IRInst* popWorkList() + { + UInt count = workList.Count(); + if(count != 0) + { + IRInst* inst = workList[count - 1]; + workList.FastRemoveAt(count - 1); + workListSet.Remove(inst); + return inst; + } + return nullptr; + } }; struct IRSpecContextBase @@ -4305,13 +4516,23 @@ namespace Slang IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } - IRSharedSpecContext::ClonedValueDictionary& getClonedValues() { return getShared()->clonedValues; } + // The current specialization environment to use. + IRSpecEnv* env = nullptr; + IRSpecEnv* getEnv() + { + // TODO: need to actually establish environments on contexts we create. + // + // Or more realistically we need to change the whole approach + // to specialization and cloning so that we don't try to share + // logic between two very different cases. + + + return env; + } // The IR builder to use for creating nodes IRBuilder* builder; - SubstitutionSet subst; - // A callback to be used when a value that is not registerd in `clonedValues` // is needed during cloning. This gives the subtype a chance to intercept // the operation and clone (or not) as needed. @@ -4319,24 +4540,6 @@ namespace Slang { return originalVal; } - - // A callback used to clone (or not) types. - virtual RefPtr<Type> maybeCloneType(Type* originalType) - { - return originalType; - } - - // A callback used to clone (or not) a declaration reference - virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) - { - return declRef; - } - - // A callback used to clone (or not) a Val - virtual RefPtr<Val> maybeCloneVal(Val* val) - { - return val; - } }; void registerClonedValue( @@ -4347,19 +4550,12 @@ namespace Slang if(!originalValue) return; - // Note: setting the entry direclty here rather than - // using `Add` or `AddIfNotExists` because we can conceivably - // clone the same value (e.g., a basic block inside a generic - // function) multiple times, and that is okay, and we really - // just need to keep track of the most recent value. - - // TODO: The same thing could potentially be handled more - // cleanly by having a notion of scoping for these cloned-value - // mappings, so that we register cloned values for things - // inside of a function to a temporary mapping that we - // throw away after the function is done. - - context->getClonedValues()[originalValue] = clonedValue; + // TODO: now that things are scoped using environments, we + // shouldn't be running into the cases where a value with + // the same key already exists. This should be changed to + // an `Add()` call. + // + context->getEnv()->clonedValues[originalValue] = clonedValue; } // Information on values to use when registering a cloned value @@ -4425,6 +4621,22 @@ namespace Slang } break; + case kIRDecorationOp_Semantic: + { + auto originalDecoration = (IRSemanticDecoration*)dd; + auto newDecoration = context->builder->addDecoration<IRSemanticDecoration>(clonedValue); + newDecoration->semanticName = originalDecoration->semanticName; + } + break; + + case kIRDecorationOp_InterpolationMode: + { + auto originalDecoration = (IRInterpolationModeDecoration*)dd; + auto newDecoration = context->builder->addDecoration<IRInterpolationModeDecoration>(clonedValue); + newDecoration->mode = originalDecoration->mode; + } + break; + default: // Don't clone any decorations we don't understand. break; @@ -4435,46 +4647,37 @@ namespace Slang clonedValue->sourceLoc = originalValue->sourceLoc; } + // We use an `IRSpecContext` for the case where we are cloning + // code from one or more input modules to create a "linked" output + // module. Along the way, we will resolve profile-specific functions + // to the best definition for a given target. + // struct IRSpecContext : IRSpecContextBase { // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - // Override teh "maybe clone" logic so that we carefully - // clone any IR proxy values inside substitutions - virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override; - - virtual RefPtr<Type> maybeCloneType(Type* originalType) override; - virtual RefPtr<Val> maybeCloneVal(Val* val) override; }; IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal); - RefPtr<Substitutions> cloneSubstitutions( - IRSpecContext* context, - Substitutions* subst); - - RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As<Type>(); - } - RefPtr<Val> IRSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue); + IRType* cloneType( + IRSpecContextBase* context, + IRType* originalType); IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) { - switch (originalValue->op) + if (auto globalValue = as<IRGlobalValue>(originalValue)) { - case kIROp_global_var: - case kIROp_global_constant: - case kIROp_Func: - case kIROp_witness_table: - return cloneGlobalValue(this, (IRGlobalValue*) originalValue); + return cloneGlobalValue(this, globalValue); + } + switch (originalValue->op) + { case kIROp_boolConst: { IRConstant* c = (IRConstant*)originalValue; @@ -4486,70 +4689,43 @@ namespace Slang case kIROp_IntLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getIntValue(c->type, c->u.intVal); + return builder->getIntValue(cloneType(this, c->getDataType()), c->u.intVal); } break; case kIROp_FloatLit: { IRConstant* c = (IRConstant*)originalValue; - return builder->getFloatValue(c->type, c->u.floatVal); + return builder->getFloatValue(cloneType(this, c->getDataType()), c->u.floatVal); } break; - case kIROp_decl_ref: + default: { - IRDeclRef* od = (IRDeclRef*)originalValue; - auto newDeclRef = od->declRef; + // In the deafult case, assume that we have some sort of "hoistable" + // instruction that requires us to create a clone of it. - // if the declRef is one of the __generic_param decl being substituted by subst - // return the substituted decl - if (subst.globalGenParamSubstitutions) + UInt argCount = originalValue->getOperandCount(); + IRInst* clonedValue = createInstWithTrailingArgs<IRInst>( + builder, + originalValue->op, + cloneType(this, originalValue->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(this, clonedValue, originalValue); + for (UInt aa = 0; aa < argCount; ++aa) { - int diff = 0; - newDeclRef = od->declRef.SubstituteImpl(subst, &diff); - for (auto globalGenSubst = subst.globalGenParamSubstitutions; globalGenSubst; globalGenSubst = globalGenSubst->outer) - { - if (!globalGenSubst) - continue; - if (newDeclRef.getDecl() == globalGenSubst->paramDecl) - return builder->getTypeVal(globalGenSubst->actualType.As<Type>()); - else if (auto genConstraint = newDeclRef.As<GenericTypeConstraintDecl>()) - { - // a decl-ref to GenericTypeConstraintDecl as a result of - // referencing a generic parameter type should be replaced with - // the actual witness table - if (genConstraint.getDecl()->ParentDecl == globalGenSubst->paramDecl) - { - // find the witness table from subst - for (auto witness : globalGenSubst->witnessTables) - { - if (witness.Key->EqualsVal(GetSup(genConstraint))) - { - auto proxyVal = witness.Value.As<IRProxyVal>(); - SLANG_ASSERT(proxyVal); - return proxyVal->inst.get(); - } - } - } - } - } + IRInst* originalArg = originalValue->getOperand(aa); + IRInst* clonedArg = cloneValue(this, originalArg); + clonedValue->getOperands()[aa].init(clonedValue, clonedArg); } - auto declRef = maybeCloneDeclRef(newDeclRef); - return builder->getDeclRefVal(declRef); - } - break; - case kIROp_TypeType: - { - IRInst* od = (IRInst*)originalValue; - int ioDiff = 0; - auto newType = od->type->SubstituteImpl(subst, &ioDiff); - return builder->getTypeVal(newType.As<Type>()); + cloneDecorations(this, clonedValue, originalValue); + + addHoistableInst(builder, clonedValue); + + return clonedValue; } break; - default: - SLANG_UNEXPECTED("no value registered for IR value"); - UNREACHABLE_RETURN(nullptr); } } @@ -4557,102 +4733,41 @@ namespace Slang IRSpecContextBase* context, IRInst* originalValue); - RefPtr<Val> cloneSubstitutionArg( - IRSpecContext* context, - Val* val) + // Find a pre-existing cloned value, or return null if none is available. + IRInst* findClonedValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - auto newIRVal = cloneValue(context, proxyVal->inst.get()); - - RefPtr<IRProxyVal> newProxyVal = new IRProxyVal(); - newProxyVal->inst.init(nullptr, newIRVal); - return newProxyVal; - } - else if (auto type = dynamic_cast<Type*>(val)) - { - return context->maybeCloneType(type); - } - else + IRInst* clonedValue = nullptr; + for (auto env = context->getEnv(); env; env = env->parent) { - return context->maybeCloneVal(val); + if (env->clonedValues.TryGetValue(originalValue, clonedValue)) + { + return clonedValue; + } } - } - RefPtr<GenericSubstitution> cloneGenericSubst(IRSpecContext* context, GenericSubstitution* genSubst) - { - if (!genSubst) - return nullptr; - - RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); - newSubst->outer = cloneGenericSubst(context, genSubst->outer); - newSubst->genericDecl = genSubst->genericDecl; - - for (auto arg : genSubst->args) - { - auto newArg = cloneSubstitutionArg(context, arg); - newSubst->args.Add(newArg); - } - return newSubst; + return nullptr; } - RefPtr<GlobalGenericParamSubstitution> cloneGlobalGenericSubst(IRSpecContext* context, GlobalGenericParamSubstitution* subst) + IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue) { - if (!subst) + if (!originalValue) return nullptr; - auto newSubst = new GlobalGenericParamSubstitution(); - newSubst->actualType = subst->actualType; - newSubst->paramDecl = subst->paramDecl; - newSubst->witnessTables = subst->witnessTables; - newSubst->outer = cloneGlobalGenericSubst(context, subst->outer); - return newSubst; - } - SubstitutionSet cloneSubstitutions( - IRSpecContext* context, - SubstitutionSet subst) - { - SubstitutionSet rs; - if (!subst) - return rs; - rs.genericSubstitutions = cloneGenericSubst(context, subst.genericSubstitutions); - rs.globalGenParamSubstitutions = cloneGlobalGenericSubst(context, subst.globalGenParamSubstitutions); - if (auto thisSubst = subst.thisTypeSubstitution) - { - RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution(); - newSubst->sourceType = thisSubst->sourceType; - rs.thisTypeSubstitution = newSubst; - } - return rs; - } - - DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef) - { - // Un-specialized decl? Nothing to do. - if (!declRef.substitutions) - return declRef; - - DeclRef<Decl> newDeclRef = declRef; - - // Scan through substitutions and clone as needed. - // - // TODO: this is wasteful since we clone *everything* - newDeclRef.substitutions = cloneSubstitutions(this, declRef.substitutions); + if (IRInst* clonedValue = findClonedValue(context, originalValue)) + return clonedValue; - return newDeclRef; + return context->maybeCloneValue(originalValue); } - IRInst* cloneValue( + IRType* cloneType( IRSpecContextBase* context, - IRInst* originalValue) + IRType* originalType) { - IRInst* clonedValue = nullptr; - if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) - { - return clonedValue; - } - - return context->maybeCloneValue(originalValue); + return (IRType*)cloneValue(context, originalType); } IRInst* maybeCloneValueWithMangledName( @@ -4670,50 +4785,19 @@ namespace Slang } return cloneValue(context, originalValue); } - - void cloneInst( + + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues); + + IRInst* cloneInst( IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst) + IRBuilder* builder, + IRInst* originalInst) { - switch (originalInst->op) - { - // TODO: are there any instruction types that need to be handled - // specially here? That would be anything that has more state - // than is visible in its operand list... - case 0: // nothing yet - default: - { - // The common case is that we just need to construct a cloned - // instruction with the right number of operands, intialize - // it, and then add it to the sequence. - UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = createInstWithTrailingArgs<IRInst>( - builder, originalInst->op, - context->maybeCloneType(originalInst->type), - 0, nullptr, - argCount, nullptr); - registerClonedValue(context, clonedInst, originalInst); - auto oldBuilder = context->builder; - context->builder = builder; - for (UInt aa = 0; aa < argCount; ++aa) - { - IRInst* originalArg = originalInst->getOperand(aa); - IRInst* clonedArg; - if (originalArg->op == kIROp_witness_table) - clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context, - ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg); - else - clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); - } - builder->addInst(clonedInst); - context->builder = oldBuilder; - cloneDecorations(context, clonedInst, originalInst); - } - - break; - } + return cloneInst(context, builder, originalInst, originalInst); } void cloneGlobalValueWithCodeCommon( @@ -4722,17 +4806,18 @@ namespace Slang IRGlobalValueWithCode* originalValue); IRGlobalVar* cloneGlobalVarImpl( - IRSpecContext* context, - IRGlobalVar* originalVar, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalVar* originalVar, IROriginalValuesForClone const& originalValues) { - auto clonedVar = context->builder->createGlobalVar( - context->maybeCloneType(originalVar->getDataType()->getValueType())); + auto clonedVar = builder->createGlobalVar( + cloneType(context, originalVar->getDataType()->getValueType())); if(auto rate = originalVar->getRate() ) { - clonedVar->type = context->builder->getSession()->getRateQualifiedType( - rate, clonedVar->type); + clonedVar->setFullType(builder->getRateQualifiedType( + rate, clonedVar->getFullType())); } registerClonedValue(context, clonedVar, originalValues); @@ -4745,7 +4830,7 @@ namespace Slang VarLayout* layout = nullptr; if (context->globalVarLayouts.TryGetValue(mangledName, layout)) { - context->builder->addLayoutDecoration(clonedVar, layout); + builder->addLayoutDecoration(clonedVar, layout); } // Clone any code in the body of the variable, since this @@ -4759,11 +4844,13 @@ namespace Slang } IRGlobalConstant* cloneGlobalConstantImpl( - IRSpecContext* context, - IRGlobalConstant* originalVal, + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalConstant* originalVal, IROriginalValuesForClone const& originalValues) { - auto clonedVal = context->builder->createGlobalConstant(context->maybeCloneType(originalVal->getFullType())); + auto clonedVal = builder->createGlobalConstant( + cloneType(context, originalVal->getFullType())); registerClonedValue(context, clonedVal, originalValues); auto mangledName = originalVal->mangledName; @@ -4781,48 +4868,111 @@ namespace Slang return clonedVal; } - IRWitnessTable* cloneWitnessTableImpl( - IRSpecContextBase* context, - IRWitnessTable* originalTable, + IRGeneric* cloneGenericImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGeneric* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGeneric(); + registerClonedValue(context, clonedVal, originalValues); + + auto mangledName = originalVal->mangledName; + clonedVal->mangledName = mangledName; + + cloneDecorations(context, clonedVal, originalVal); + + // Clone any code in the body of the generic, since this + // computes its result value. + cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); + + return clonedVal; + } + + void cloneSimpleGlobalValueImpl( + IRSpecContextBase* context, + IRGlobalValue* originalInst, IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr, - bool registerValue = true) + IRGlobalValue* clonedInst, + bool registerValue = true) { - auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable(); if (registerValue) - registerClonedValue(context, clonedTable, originalValues); + registerClonedValue(context, clonedInst, originalValues); - auto mangledName = originalTable->mangledName; - - clonedTable->mangledName = mangledName; - clonedTable->genericDecl = originalTable->genericDecl; - clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef; - clonedTable->supTypeDeclRef = originalTable->supTypeDeclRef; - cloneDecorations(context, clonedTable, originalTable); + auto mangledName = originalInst->mangledName; + clonedInst->mangledName = mangledName; - // Clone the entries in the witness table as well - for(auto originalEntry : originalTable->getEntries() ) - { - auto clonedKey = cloneValue(context, originalEntry->requirementKey.get()); - - // if a global val with the mangled name already exists, don't clone again - auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.get())); + cloneDecorations(context, clonedInst, originalInst); - /*auto clonedEntry = */context->builder->createWitnessTableEntry( - clonedTable, - clonedKey, - clonedVal); + // Set up an IR builder for inserting into the inst + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + // Clone any children of the instruction + for (auto child : originalInst->getChildren()) + { + cloneInst(context, builder, child); } + } + IRStructKey* cloneStructKeyImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructKey* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->createStructKey(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + IRGlobalGenericParam* cloneGlobalGenericParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalGenericParam* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGlobalGenericParam(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; + } + + + IRWitnessTable* cloneWitnessTableImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues, + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) + { + auto clonedTable = dstTable ? dstTable : builder->createWitnessTable(); + cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); return clonedTable; } IRWitnessTable* cloneWitnessTableWithoutRegistering( IRSpecContextBase* context, + IRBuilder* builder, IRWitnessTable* originalTable, IRWitnessTable* dstTable = nullptr) { - return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false); + return cloneWitnessTableImpl(context, builder, originalTable, IROriginalValuesForClone(), dstTable, false); + } + + IRStructType* cloneStructTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructType* originalStruct, + IROriginalValuesForClone const& originalValues) + { + auto clonedStruct = builder->createStructType(); + cloneSimpleGlobalValueImpl(context, originalStruct, originalValues, clonedStruct); + return clonedStruct; } void cloneGlobalValueWithCodeCommon( @@ -4887,11 +5037,14 @@ namespace Slang } - void checkIRDuplicate(IRParentInst* moduleInst, Name* mangledName) + void checkIRDuplicate(IRInst* inst, IRParentInst* moduleInst, Name* mangledName) { #ifdef _DEBUG for (auto child : moduleInst->getChildren()) { + if (child == inst) + continue; + if (child->op == kIROp_Func) { auto extName = ((IRGlobalValue*)child)->mangledName; @@ -4902,6 +5055,7 @@ namespace Slang } } #else + SLANG_UNREFERENCED_PARAMETER(inst); SLANG_UNREFERENCED_PARAMETER(moduleInst); SLANG_UNREFERENCED_PARAMETER(mangledName); #endif @@ -4915,9 +5069,7 @@ namespace Slang { // First clone all the simple properties. clonedFunc->mangledName = originalFunc->mangledName; - clonedFunc->genericDecls = originalFunc->genericDecls; - clonedFunc->specializedGenericLevel = originalFunc->specializedGenericLevel; - clonedFunc->type = context->maybeCloneType(originalFunc->type); + clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); cloneDecorations(context, clonedFunc, originalFunc); @@ -4930,10 +5082,9 @@ namespace Slang // it needs to follow its dependencies. // // TODO: This isn't really a good requirement to place on the IR... - clonedFunc->removeFromParent(); + clonedFunc->moveToEnd(); if (checkDuplicate) - checkIRDuplicate(context->getModule()->getModuleInst(), clonedFunc->mangledName); - clonedFunc->insertAtEnd(context->getModule()->getModuleInst()); + checkIRDuplicate(clonedFunc, context->getModule()->getModuleInst(), clonedFunc->mangledName); } IRFunc* specializeIRForEntryPoint( @@ -5072,17 +5223,51 @@ namespace Slang return result; } + IRInst* findGenericReturnVal(IRGeneric* generic) + { + auto lastBlock = generic->getLastBlock(); + if (!lastBlock) + return nullptr; + + auto returnInst = as<IRReturnVal>(lastBlock->getTerminator()); + if (!returnInst) + return nullptr; + + auto val = returnInst->getVal(); + return val; + } + bool isDefinition( - IRGlobalValue* val) + IRGlobalValue* inVal) { + IRInst* val = inVal; + // unwrap any generic declarations to see + // the value they return. + for(;;) + { + auto genericInst = as<IRGeneric>(val); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + val = returnVal; + } + switch (val->op) { - case kIROp_witness_table: - case kIROp_global_var: - case kIROp_global_constant: + case kIROp_WitnessTable: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: case kIROp_Func: + case kIROp_Generic: return ((IRParentInst*)val)->getFirstChild() != nullptr; + case kIROp_StructType: + return true; + default: return false; } @@ -5146,51 +5331,92 @@ namespace Slang } IRFunc* cloneFuncImpl( - IRSpecContext* context, - IRFunc* originalFunc, + IRSpecContextBase* context, + IRBuilder* builder, + IRFunc* originalFunc, IROriginalValuesForClone const& originalValues) { - auto clonedFunc = context->builder->createFunc(); + auto clonedFunc = builder->createFunc(); registerClonedValue(context, clonedFunc, originalValues); cloneFunctionCommon(context, clonedFunc, originalFunc); return clonedFunc; } - // Directly clone a global value, based on a single definition/declaration, `originalVal`. - // The symbol `sym` will thread together other declarations of the same value, and - // we will register the new value as the cloned version of all of those. - IRGlobalValue* cloneGlobalValueImpl( - IRSpecContext* context, - IRGlobalValue* originalVal, - IRSpecSymbol* sym) - { - if( !originalVal ) - { - SLANG_UNEXPECTED("cloning a null value"); - UNREACHABLE_RETURN(nullptr); - } - switch( originalVal->op ) + IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues) + { + switch (originalInst->op) { + // We need to special-case any instruction that is not + // allocated like an ordinary `IRInst` with trailing args. case kIROp_Func: - return cloneFuncImpl(context, (IRFunc*) originalVal, sym); + return cloneFuncImpl(context, builder, cast<IRFunc>(originalInst), originalValues); + + case kIROp_GlobalVar: + return cloneGlobalVarImpl(context, builder, cast<IRGlobalVar>(originalInst), originalValues); + + case kIROp_GlobalConstant: + return cloneGlobalConstantImpl(context, builder, cast<IRGlobalConstant>(originalInst), originalValues); + + case kIROp_WitnessTable: + return cloneWitnessTableImpl(context, builder, cast<IRWitnessTable>(originalInst), originalValues); - case kIROp_global_var: - return cloneGlobalVarImpl(context, (IRGlobalVar*)originalVal, sym); + case kIROp_StructType: + return cloneStructTypeImpl(context, builder, cast<IRStructType>(originalInst), originalValues); + + case kIROp_Generic: + return cloneGenericImpl(context, builder, cast<IRGeneric>(originalInst), originalValues); - case kIROp_global_constant: - return cloneGlobalConstantImpl(context, (IRGlobalConstant*)originalVal, sym); + case kIROp_StructKey: + return cloneStructKeyImpl(context, builder, cast<IRStructKey>(originalInst), originalValues); - case kIROp_witness_table: - return cloneWitnessTableImpl(context, (IRWitnessTable*)originalVal, sym); + case kIROp_GlobalGenericParam: + return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues); default: - SLANG_UNEXPECTED("unknown global value kind"); - UNREACHABLE_RETURN(nullptr); + break; } + // The common case is that we just need to construct a cloned + // instruction with the right number of operands, intialize + // it, and then add it to the sequence. + UInt argCount = originalInst->getOperandCount(); + IRInst* clonedInst = createInstWithTrailingArgs<IRInst>( + builder, originalInst->op, + cloneType(context, originalInst->getFullType()), + 0, nullptr, + argCount, nullptr); + registerClonedValue(context, clonedInst, originalValues); + auto oldBuilder = context->builder; + context->builder = builder; + for (UInt aa = 0; aa < argCount; ++aa) + { + IRInst* originalArg = originalInst->getOperand(aa); + IRInst* clonedArg = cloneValue(context, originalArg); + clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); + + return clonedInst; } + IRGlobalValue* cloneGlobalValueImpl( + IRSpecContext* context, + IRGlobalValue* originalInst, + IROriginalValuesForClone const& originalValues) + { + auto clonedValue = cloneInst(context, &context->shared->builderStorage, originalInst, originalValues); + clonedValue->moveToEnd(); + return cast<IRGlobalValue>(clonedValue); + } + + // Clone a global value, which has the given `mangledName`. // The `originalVal` is a known global IR value with that name, if one is available. // (It is okay for this parameter to be null). @@ -5202,7 +5428,7 @@ namespace Slang // If the global value being cloned is already in target module, don't clone // Why checking this? // When specializing a generic function G (which is already in target module), - // where G calls a normal function F (which is already in target module), + // where G calls a normal function F (which is already in target module), // then when we are making a copy of G via cloneFuncCommom(), it will recursively clone F, // however we don't want to make a duplicate of F in the target module. if (originalVal->getParent() == context->getModule()->getModuleInst()) @@ -5210,17 +5436,19 @@ namespace Slang // Check if we've already cloned this value, for the case where // an original value has already been established. - IRInst* clonedVal = nullptr; - if( originalVal && context->getClonedValues().TryGetValue(originalVal, clonedVal) ) + if (originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, originalVal)) + { + return cast<IRGlobalValue>(clonedVal); + } } if(getText(mangledName).Length() == 0) { // If there is no mangled name, then we assume this is a local symbol, // and it can't possibly have multiple declarations. - return cloneGlobalValueImpl(context, originalVal, nullptr); + return cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone()); } // @@ -5236,7 +5464,7 @@ namespace Slang // This shouldn't happen! SLANG_UNEXPECTED("no matching values registered"); - UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone())); } // We will try to track the "best" declaration we can find. @@ -5256,12 +5484,15 @@ namespace Slang // Check if we've already cloned this value, for the case where // we didn't have an original value (just a name), but we've // now found a representative value. - if( !originalVal && context->getClonedValues().TryGetValue(bestVal, clonedVal) ) + if (!originalVal) { - return (IRGlobalValue*) clonedVal; + if (IRInst* clonedVal = findClonedValue(context, bestVal)) + { + return cast<IRGlobalValue>(clonedVal); + } } - return cloneGlobalValueImpl(context, bestVal, sym); + return cloneGlobalValueImpl(context, bestVal, IROriginalValuesForClone(sym)); } IRGlobalValue* cloneGlobalValueWithMangledName(IRSpecContext* context, Name* mangledName) @@ -5365,11 +5596,6 @@ namespace Slang ProgramLayout* programLayout, SubstitutionSet typeSubst); - RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context); - struct IRSpecializationState { ProgramLayout* programLayout; @@ -5382,8 +5608,16 @@ namespace Slang IRSharedSpecContext sharedContextStorage; IRSpecContext contextStorage; + IRSpecEnv globalEnv; + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } IRSpecContext* getContext() { return &contextStorage; } + + IRSpecializationState() + { + contextStorage.env = &globalEnv; + } + ~IRSpecializationState() { newProgramLayout = nullptr; @@ -5429,19 +5663,27 @@ namespace Slang auto context = state->getContext(); context->shared = sharedContext; context->builder = &sharedContext->builderStorage; - // Create the GlobalGenericParamSubstitution for substituting global generic types - // into user-provided type arguments - auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context); - context->subst.globalGenParamSubstitutions = globalParamSubst; - - // now specailize the program layout using the substitution - RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, context->subst); + // Now specialize the program layout using the substitution + // + // TODO: The specialization of the layout is conceptually an AST-level operations, + // and shouldn't be done here in the IR at all. + // + RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout( + targetReq, + programLayout, + SubstitutionSet(entryPointRequest->globalGenericSubst)); + + // TODO: we need to register the (IR-level) arguments of the global generic parameters as the + // substitutions for the generic parameters in the original IR. + + // applyGlobalGenericParamSubsitution(...); + state->newProgramLayout = newProgramLayout; // Next, we want to optimize lookup for layout infromation - // associated with global declarations, so that we can + // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) auto globalStructLayout = getGlobalStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) @@ -5453,7 +5695,7 @@ namespace Slang // for now, clone all unreferenced witness tables for (auto sym :context->getSymbols()) { - if (sym.Value->irGlobalValue->op == kIROp_witness_table) + if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } return state; @@ -5526,6 +5768,20 @@ namespace Slang // it might reference. auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout); + // HACK: right now the bindings for global generic parameters are coming in + // as part of the original IR module, and we need to make sure these get + // copied over, even if they aren't referenced. + // + for(auto inst : originalIRModule->getGlobalInsts()) + { + auto bindInst = as<IRBindGlobalGenericParam>(inst); + if(!bindInst) + continue; + + cloneValue(context, bindInst); + } + + // TODO: *technically* we should consider the case where // we have global variables with initializers, since // these should get run whether or not the entry point @@ -5551,7 +5807,7 @@ namespace Slang break; } } - + struct IRGenericSpecContext : IRSpecContextBase { IRSpecContextBase* parent = nullptr; @@ -5560,383 +5816,69 @@ namespace Slang // Override the "maybe clone" logic so that we always clone virtual IRInst* maybeCloneValue(IRInst* originalVal) override; - - virtual RefPtr<Type> maybeCloneType(Type* originalType) override; - virtual RefPtr<Val> maybeCloneVal(Val* val) override; }; - // Convert a type-level value into an IR-level equivalent. - IRInst* getIRValue( - IRGenericSpecContext* context, - Val* val) + IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) { - if( auto subtypeWitness = dynamic_cast<SubtypeWitness*>(val) ) - { - auto mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subtypeWitness->sub, - subtypeWitness->sup)); - RefPtr<IRSpecSymbol> symbol; - - if (context->getSymbols().TryGetValue(mangledName, symbol)) - { - // Note: the symbols always come from the source module, - // not the destination module, so we may need to clone - // them if we are doing an initialize specialization pass. - return cloneValue(context, symbol->irGlobalValue); - } - else - { - // we don't have the required witness table yet, - // try to emit a specialize instruction to get one - auto subDeclRef = subtypeWitness->sub->AsDeclRefType(); - auto subDeclRefGen = DeclRef<Decl>(subDeclRef->declRef.decl, - createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl)); - - auto genericName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness( - subDeclRefGen, - subtypeWitness->sup)); - if (context->getSymbols().TryGetValue(genericName, symbol)) - { - auto clonedSymbol = cloneValue(context, symbol->irGlobalValue); - auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, clonedSymbol, subDeclRef->declRef); - return specInst; - } - else - { - SLANG_UNEXPECTED("witness table not exist"); - UNREACHABLE_RETURN(nullptr); - } - } - } - else if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + if (parent) { - return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value); - } - else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) - { - // The type-level value actually references an IR-level value, - // so we need to make sure to emit as if we were referencing - // the pointed-to value and not the proxy type-level `Val` - // instead. - - return context->maybeCloneValue(proxyVal->inst.get()); + return parent->maybeCloneValue(originalVal); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); + return originalVal; } } - IRInst* getSubstValue( - IRGenericSpecContext* context, - DeclRef<Decl> declRef) + // See the work list for the generic spec context with + // every relevant instruction from `inst` through its + // descendents. + void addToSpecializationWorkListRec( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) { - auto subst = context->subst.genericSubstitutions; - SLANG_ASSERT(subst); - auto genericDecl = subst->genericDecl; - - UInt orinaryParamCount = 0; - for( auto mm : genericDecl->Members ) + if(auto genericInst = as<IRGeneric>(inst)) { - if(mm.As<GenericTypeParamDecl>()) - orinaryParamCount++; - else if(mm.As<GenericValueParamDecl>()) - orinaryParamCount++; + // We do *not* consider generics, or instructions nested under them. + return; } - - if( auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>() ) + else if(auto parentInst = as<IRParentInst>(inst)) { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt constraintIndex = 0; - bool found = false; - for( auto cd : genericDecl->getMembersOfType<GenericTypeConstraintDecl>() ) - { - if( cd.Ptr() == constraintDeclRef.getDecl() ) - { - found = true; - break; - } - - constraintIndex++; - } - assert(found); + // For a parent instruction, we will scan through its contents, + // since that will be where the `specialize` instructions are - UInt argIndex = orinaryParamCount + constraintIndex; - assert(argIndex < subst->args.Count()); - - return getIRValue(context, subst->args[argIndex]); - } - else if (auto valDeclRef = declRef.As<GenericValueParamDecl>()) - { - // We have a constraint, but we need to find its index in the - // argument list of the substitutions. - UInt argIdx = 0; - bool found = false; - for (auto cd : genericDecl->Members) + for(auto child : parentInst->children) { - if (cd.Ptr() == valDeclRef.getDecl()) - { - found = true; - break; - } - if (cd.As<GenericTypeParamDecl>()) - argIdx++; - else if (cd.As<GenericValueParamDecl>()) - argIdx++; + addToSpecializationWorkListRec(sharedContext, child); } - assert(found); - - assert(argIdx < subst->args.Count()); - - return getIRValue(context, subst->args[argIdx]); } else { - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(nullptr); - } - } - - IRInst* IRGenericSpecContext::maybeCloneValue(IRInst* originalVal) - { - switch( originalVal->op ) - { - case kIROp_decl_ref: - { - auto declRefVal = (IRDeclRef*) originalVal; - auto declRef = declRefVal->declRef; - auto genSubst = subst.genericSubstitutions; - SLANG_ASSERT(genSubst); - // We may have a direct reference to one of the parameters - // of the generic we are specializing, and in that case - // we nee to translate it over to the equiavalent of - // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == genSubst->genericDecl && - (declRef.As<GenericTypeParamDecl>() || declRef.As<GenericValueParamDecl>()|| - declRef.As<GenericTypeConstraintDecl>())) - { - if (auto substVal = getSubstValue(this, declRef)) - return substVal; - } - int diff = 0; - auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); - if(!diff) - return originalVal; - - return builder->getDeclRefVal(substDeclRef); - } - break; - - default: - if (parent) - { - return parent->maybeCloneValue(originalVal); - } - else - { - return originalVal; - } - } - } - - RefPtr<Type> IRGenericSpecContext::maybeCloneType(Type* originalType) - { - return originalType->Substitute(subst).As<Type>(); - } - - RefPtr<Val> IRGenericSpecContext::maybeCloneVal(Val * val) - { - return val->Substitute(subst); - } - - // Given a list of substitutions, return the inner-most - // generic substitution in the list, or NULL if there - // are no generic substitutions. - RefPtr<GenericSubstitution> getInnermostGenericSubst( - SubstitutionSet inSubst) - { - return inSubst.genericSubstitutions; - } - - RefPtr<GenericDecl> getInnermostGenericDecl( - Decl* inDecl) - { - auto decl = inDecl; - while( decl ) - { - GenericDecl* genericDecl = dynamic_cast<GenericDecl*>(decl); - if(genericDecl) - return genericDecl; - - decl = decl->ParentDecl; + // Default case: consider this instruction for specialization. + sharedContext->addToWorkList(inst); } - return nullptr; } - // This function takes a list of substitutions that we'd - // like to apply, but which might apply to a different - // declaration in cases where we have got target-specific - // overloads in the mix, and produces a new set of - // substitutiosn without this issue. - RefPtr<GenericSubstitution> cloneSubstitutionsForSpecialization( - IRSharedSpecContext* sharedContext, - RefPtr<GenericSubstitution> oldSubst, - Decl* newDecl) - { - // We will "peel back" layers of substitutions until - // we find our first generic subsitution. - auto oldGenericSubst = oldSubst; - if(!oldGenericSubst) - return nullptr; - - auto innerGenericName = oldGenericSubst->genericDecl->inner->getName(); - - // We will also peel back layers of declarations until - // we find our first generic decl. - GenericDecl* newGenericDecl = nullptr; - - for (Decl* d = newDecl; d; d = d->ParentDecl) - { - if (auto gd = dynamic_cast<GenericDecl*>(d)) - { - if (gd->inner->getName() == innerGenericName) - { - newGenericDecl = gd; - break; - } - } - } - - if( !newGenericDecl ) - { - if(auto gd = dynamic_cast<GenericDecl*>(newDecl)) - { - if( auto ed = gd->inner.As<ExtensionDecl>() ) - { - // TODO: we should confirm that it is an extension for the correct type... - - newGenericDecl = gd; - } - } - } - - SLANG_ASSERT(newGenericDecl); - - RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); - newSubst->genericDecl = newGenericDecl; - newSubst->args = oldGenericSubst->args; - - newSubst->outer = cloneSubstitutionsForSpecialization( - sharedContext, - oldGenericSubst->outer, - newGenericDecl->ParentDecl); - - return newSubst; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef<Decl> specDeclRef); - - IRWitnessTable* specializeWitnessTable( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRWitnessTable* originalTable, - DeclRef<Decl> specDeclRef, - IRWitnessTable* dstTable) + IRInst* specializeGeneric( + IRSharedGenericSpecContext* sharedContext, + IRSpecContextBase* parentContext, + IRGeneric* genericVal, + IRSpecialize* specializeInst) { // First, we want to see if an existing specialization // has already been made. To do that we will need to - // compute the mangled name of the specialized function, + // compute the mangled name of the specialized value, // so that we can look for existing declarations. - String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef), - specDeclRef.Substitute(originalTable->supTypeDeclRef)); - - if (dstTable && getText(dstTable->mangledName).Length()) - specializedMangledName = getText(dstTable->mangledName); - - // TODO: This is a terrible linear search, and we should - // avoid it by building a dictionary ahead of time, - // as is being done for the `IRSpecContext` used above. - // We can probalby use the same basic context, actually. - if (!dstTable) - { - auto module = sharedContext->module; - for(auto ii : module->getGlobalInsts()) - { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; - - if (getText(gv->mangledName) == specializedMangledName) - return (IRWitnessTable*)gv; - } - } - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - originalTable->genericDecl); - - IRGenericSpecContext context; - context.shared = sharedContext; - context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; - // TODO: other initialization is needed here... - - auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable); - - // Set up the clone to recognize that it is no longer generic - specTable->mangledName = context.getModule()->session->getNameObj(specializedMangledName); - specTable->genericDecl = nullptr; - - // Specialization of witness tables should trigger cascading specializations - // of involved functions. - for (auto entry : specTable->getEntries()) - { - if (entry->satisfyingVal.get()->op == kIROp_Func) - { - IRFunc* func = (IRFunc*)entry->satisfyingVal.get(); - auto specFunc = getSpecializedFunc(sharedContext, parentContext, func, specDeclRef); - entry->satisfyingVal.set(specFunc); - insertGlobalValueSymbol(sharedContext, specFunc); - } - - } - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specTable); - - return specTable; - } - - IRFunc* getSpecializedFunc( - IRSharedSpecContext* sharedContext, - IRSpecContextBase* parentContext, - IRFunc* genericFunc, - DeclRef<Decl> specDeclRef) - { - // First, we want to see if an existing specialization - // has already been made. To do that we will need to - // compute the mangled name of the specialized function, - // so that we can look for existing declarations. - String specMangledName; - if (genericFunc->getGenericDecl() == specDeclRef.decl) - specMangledName = getMangledName(specDeclRef); - else - specMangledName = mangleSpecializedFuncName(getText(genericFunc->mangledName), specDeclRef.substitutions); + String specMangledName = mangleSpecializedFuncName(getText(genericVal->mangledName), specializeInst); auto specMangledNameObj = sharedContext->module->session->getNameObj(specMangledName); + + // Now look up an existing symbol with a matching name RefPtr<IRSpecSymbol> symb; if (sharedContext->symbols.TryGetValue(specMangledNameObj, symb)) { - return (IRFunc*)(symb->irGlobalValue); + return symb->irGlobalValue; } + // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. @@ -5948,104 +5890,285 @@ namespace Slang continue; if (gv->mangledName == specMangledNameObj) - return (IRFunc*) gv; + return gv; } // If we get to this point, then we need to construct a - // new `IRFunc` to represent the result of specialization. + // new IR value to represent the result of specialization. - // The substitutions we are applying might have been created - // using a different overload of a target-specific function, - // so we need to create a dummy substitution here, to make - // sure it used the correct generic. - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( - sharedContext, - specDeclRef.substitutions.genericSubstitutions, - genericFunc->getGenericDecl()); + // We need to establish a new mapping from inst->inst to + // handle the specialization, because we don't want the + // clones we register in this pass to cause confusion + // in later steps that might clone the same code. + + IRSpecEnv env; + env.parent = &sharedContext->globalEnv; + if (parentContext) + { + env.parent = parentContext->getEnv(); + } - if (!newSubst) - return genericFunc; + // The result of specialization should be inserted + // into the global scope, at the same location as + // the original generic. + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(genericVal); IRGenericSpecContext context; context.shared = sharedContext; context.parent = parentContext; - context.builder = &sharedContext->builderStorage; - context.subst = specDeclRef.substitutions; - context.subst.genericSubstitutions = newSubst; + context.builder = builder; + context.env = &env; - // TODO: other initialization is needed here... + // Register the arguments of the `specialize` instruction to be used + // as the "cloned" value for each of the parameters of the generic. + // + UInt argCounter = 0; + for (auto param = genericVal->getFirstParam(); param; param = param->getNextParam()) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - auto specFunc = cloneSimpleFuncWithoutRegistering(&context, genericFunc); + IRInst* arg = specializeInst->getArg(argIndex); - specFunc->mangledName = context.getModule()->session->getNameObj(specMangledName); - - // reduce specialized generic level by 1 - if (specFunc->specializedGenericLevel >= 0) - specFunc->specializedGenericLevel--; + registerClonedValue(&context, arg, param); + } - // Put the function into the global sequence right after - // the function it specializes. - // - // TODO: This shouldn't be needed, if we introduce a sorting - // step before we emit code. - //specFunc->removeFromParent(); - //specFunc->insertAfter(genericFunc); + // Okay, now we want to run through the body of the generic + // and clone stuff into the parent scope (which had + // better be the global scope). + for (auto bb : genericVal->getBlocks()) + { + // We expect a generic to only ever contain a single block. + SLANG_ASSERT(bb == genericVal->getFirstBlock()); - // At this point we've created a new non-generic function, - // which means we should add it to our work list for - // subsequent processing. - if (specFunc->specializedGenericLevel == -1) - sharedContext->workList.Add(specFunc); + for (auto ii : bb->getChildren()) + { + // Skip parameters, since they were handled earlier. + if (auto param = as<IRParam>(ii)) + continue; + + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + if (auto returnValInst = as<IRReturnVal>(ii)) + { + auto clonedResult = cloneValue(&context, returnValInst->getVal()); + if (auto clonedGlobalValue = as<IRGlobalValue>(clonedResult)) + { + clonedGlobalValue->mangledName = specMangledNameObj; + + // TODO: create a symbol for it and add it to the map. + } + + return clonedResult; + } - // We also need to make sure that we register this specialized - // function under its mangled name, so that later lookup - // steps will find it. - insertGlobalValueSymbol(sharedContext, specFunc); + // Otherwise, clone the instruction into the global scope + IRInst* clonedInst = cloneInst(&context, context.builder, ii); - return specFunc; + // Now that we've cloned the instruction to a location outside + // of a generic, we should consider whether it can now be specialized. + addToSpecializationWorkListRec(sharedContext, clonedInst); + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); } // Find the value in the given witness table that // satisfies the given requirement (or return // null if not found). IRInst* findWitnessVal( - IRWitnessTable* witnessTable, - DeclRef<Decl> const& requirementDeclRef) + IRWitnessTable* witnessTable, + IRInst* requirementKey) { // For now we will do a dumb linear search for( auto entry : witnessTable->getEntries() ) { - // We expect the key on the entry to be a decl-ref, - // but lets go ahead and check, just to be sure. - auto requirementKey = entry->requirementKey.get(); - if(requirementKey->op != kIROp_decl_ref) + // If the keys matched, then we use the value from this entry. + if (requirementKey == entry->requirementKey.get()) + { + auto satisfyingVal = entry->satisfyingVal.get(); + return satisfyingVal; + } + } + + // No matching entry found. + return nullptr; + } + + static bool canSpecializeGeneric( + IRGeneric* generic) + { + IRGeneric* g = generic; + for(;;) + { + auto val = findGenericReturnVal(g); + if(!val) + return false; + + if (auto nestedGeneric = as<IRGeneric>(val)) + { + // The outer generic returns an *inner* generic + // (so that multiple calls to `specialize` are + // needed to resolve it). We should look at + // what the nested generic returns to figure + // out whether specialization is allowed. + g = nestedGeneric; continue; - auto keyDeclRef = ((IRDeclRef*) requirementKey)->declRef; + } - // If the keys don't match, continue with the next entry. - if (!keyDeclRef.Equals(requirementDeclRef)) + // We've found the leaf value that will be produced after + // all of the specialization is done. Now we want to know + // if that is a value suitable for actually specializing + + if (auto globalValue = as<IRGlobalValue>(val)) { - // requirementDeclRef may be pointing to the inner decl of a generic decl - // in this case we compare keyDeclRef against the parent decl of requiredDeclRef - if (auto genRequiredDeclRef = requirementDeclRef.GetParent().As<GenericDecl>()) + if (isDefinition(globalValue)) + return true; + return false; + } + else + { + // There might be other cases with a declaration-vs-definition + // thing that we need to handle. + + return true; + } + } + } + + // Add any instruction that uses `inst` to the work list, + // so that it can be evaluated (or re-evaluated) for specialization. + void addUsesToWorkList( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + for(auto u = inst->firstUse; u; u = u->nextUse) + { + sharedContext->addToWorkList(u->getUser()); + } + } + + void specializeGenericsForInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + switch(inst->op) + { + default: + // The default behavior is to do nothing. + // An instruction is specialize-able once its operands + // are specialized, and after that it is also safe + // to consider the instruction specialized. + break; + + case kIROp_Specialize: + { + // We have a `specialize` instruction, so lets see + // whether we have an opportunity to perform the + // specialization here and now. + IRSpecialize* specInst = cast<IRSpecialize>(inst); + + // Look at the base of the `specialize`, and see if + // it directly names a generic, so that we can apply + // specialization here and now. + auto baseVal = specInst->getBase(); + if(auto genericVal = as<IRGeneric>(baseVal)) { - if (!keyDeclRef.Equals(genRequiredDeclRef)) + if (canSpecializeGeneric(genericVal)) { - continue; + // Okay, we have a candidate for specialization here. + // + // We will apply the specialization logic to the body of the generic, + // which will yield, e.g., a specialized `IRFunc`. + // + auto specializedVal = specializeGeneric(sharedContext, nullptr, genericVal, specInst); + // + // Then we will replace the use sites for the `specialize` + // instruction with uses of the specialized value. + // + addUsesToWorkList(sharedContext, specInst); + specInst->replaceUsesWith(specializedVal); + specInst->removeAndDeallocate(); } } - else - continue; } + break; + + case kIROp_lookup_interface_method: + { + // We have a `lookup_interface_method` instruction, + // so let's see whether it is a lookup in a known + // witness table. + IRLookupWitnessMethod* lookupInst = cast<IRLookupWitnessMethod>(inst); + + // We only want to deal with the case where the witness-table + // argument points to a concrete global table (and not, e.g., a + // `specialize` instruction that will yield a table) + auto witnessTable = as<IRWitnessTable>(lookupInst->witnessTable.get()); + if(!witnessTable) + break; + + // Use the witness table to look up the value that + // satisfies the requirement. + auto requirementKey = lookupInst->getRequirementKey(); + auto satisfyingVal = findWitnessVal(witnessTable, requirementKey); + // We expect to always find something, but lets just + // be careful here. + if(!satisfyingVal) + break; - // If the keys matched, then we use the value from - // this entry. - auto satisfyingVal = entry->satisfyingVal.get(); - return satisfyingVal; + // If we get through all of the above checks, then we + // have a (more) concrete method that implements the interface, + // and so we should dispatch to that directly, rather than + // use the `lookup_interface_method` instruction. + addUsesToWorkList(sharedContext, lookupInst); + lookupInst->replaceUsesWith(satisfyingVal); + lookupInst->removeAndDeallocate(); + } + break; } + } - // No matching entry found. - return nullptr; + static bool isInstSpecialized( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // If an instruction is still on our work list, then + // it isn't specialized, and conversely we say that + // if it *isn't* on the work list, it must be specialized. + // + // Note: if we end up with bugs in this logic, we could + // maintain an explicit set of specialized insts instead. + // + return !sharedContext->workListSet.Contains(inst); + } + + static bool canSpecializeInst( + IRSharedGenericSpecContext* sharedContext, + IRInst* inst) + { + // We can specialize an instruction once all its + // operands are specialized. + + UInt operandCount = inst->getOperandCount(); + for(UInt ii = 0; ii < operandCount; ++ii) + { + IRInst* operand = inst->getOperand(ii); + if(!isInstSpecialized(sharedContext, operand)) + return false; + } + return true; } // Go through the code in the module and try to identify @@ -6056,7 +6179,7 @@ namespace Slang IRModule* module, CodeGenTarget target) { - IRSharedSpecContext sharedContextStorage; + IRSharedGenericSpecContext sharedContextStorage; auto sharedContext = &sharedContextStorage; initializeSharedSpecContext( @@ -6066,351 +6189,127 @@ namespace Slang module, target); - // Our goal here is to find `specialize` instructions that - // can be replaced with references to a suitably sepcialized - // funciton. As a simplification, we will only consider `specialize` - // calls that are inside of non-generic functions, since we assume - // that these will allow us to fully specialize the referenced - // function. - // - // We start by building up a work list of non-generic functions. - for(auto ii : module->getGlobalInsts()) - { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; + auto moduleInst = module->getModuleInst(); - // Is it a function? If not, skip. - if(gv->op != kIROp_Func) + // First things first, let's deal with any bindings for global generic parameters. + for(auto inst : moduleInst->getChildren()) + { + auto bindInst = as<IRBindGlobalGenericParam>(inst); + if(!bindInst) continue; - auto func = (IRFunc*) gv; - // Is it generic? If so, skip. - if(func->getGenericDecl()) - continue; + auto param = bindInst->getParam(); + auto val = bindInst->getVal(); - sharedContext->workList.Add(func); + param->replaceUsesWith(val); } - - // Build dictionary for witness tables - Dictionary<Name*, IRWitnessTable*> witnessTables; - for(auto ii : module->getGlobalInsts()) { - auto gv = as<IRGlobalValue>(ii); - if (!gv) - continue; - - if (gv->op == kIROp_witness_table) - witnessTables.AddIfNotExists(gv->mangledName, (IRWitnessTable*)gv); - } - - // Now that we have our work list, we are going to - // process it until it goes empty. Along the way - // we may specialize a function and thus create - // a new non-generic function, and in that case - // we will add the new function to the work list. - auto& workList = sharedContext->workList; - while( auto count = workList.Count() ) - { - // We will process the last entry in the - // work list, which amounts to treating - // it like a stack when we have recursive - // specialization to perform. - auto func = workList[count-1]; - workList.RemoveAt(count-1); - - // We are going to go ahead and walk through - // all the instructions in this function, - // and look for `specialize` operations. - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + // Now we will do a second pass to clean up the + // generic parameters and their bindings. + IRInst* next = nullptr; + for(auto inst = moduleInst->getFirstChild(); inst; inst = next) { - // We need to be careful when iterating over the instructions, - // because we might end up removing the "current" instruction, - // so that accessing `ii->next` would crash. - IRInst* nextInst = nullptr; - for( auto ii = bb->getFirstInst(); ii; ii = nextInst ) - { - nextInst = ii->getNextInst(); - - // We want to handle both `specialize` instructions, - // which trigger specialization, and also `lookup_interface_method` - // instructions, which may allow us to "de-virtualize" - // calls. - - switch( ii->op ) - { - default: - // Most instructions are ones we don't care about here. - continue; - - case kIROp_specialize: - { - // We have a `specialize` instruction, so lets see - // whether we have an opportunity to perform the - // specialization here and now. - IRSpecialize* specInst = (IRSpecialize*) ii; - - // Now we extract the specialized decl-ref that will - // tell us how to specialize things. - auto specDeclRefVal = (IRDeclRef*)specInst->specDeclRefVal.get(); - auto specDeclRef = specDeclRefVal->declRef; - - // We need to specialize functions and witness tables - auto genericVal = specInst->genericVal.get(); - if (genericVal->op == kIROp_Func) - { - auto genericFunc = (IRFunc*)genericVal; - if (!genericFunc->getGenericDecl()) - continue; - - // Okay, we have a candidate for specialization here. - // - // We will first find or construct a specialized version - // of the callee funciton/ - auto specFunc = getSpecializedFunc(sharedContext, nullptr, genericFunc, specDeclRef); - // - // Then we will replace the use sites for the `specialize` - // instruction with uses of the specialized function. - // - specInst->replaceUsesWith(specFunc); - - specInst->removeAndDeallocate(); - } - else if (genericVal->op == kIROp_witness_table) - { - // specialize a witness table - auto originalTable = (IRWitnessTable*)genericVal; - auto specWitnessTable = specializeWitnessTable(sharedContext, nullptr, originalTable, specDeclRef, nullptr); - witnessTables.AddIfNotExists(specWitnessTable->mangledName, specWitnessTable); - specInst->replaceUsesWith(specWitnessTable); - specInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_witness_table: - { - // try find concrete witness table from global scope - IRLookupWitnessTable* lookupInst = (IRLookupWitnessTable*)ii; - IRWitnessTable* witnessTable = nullptr; - auto srcDeclRef = ((IRDeclRef*)lookupInst->sourceType.get())->declRef; - auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.get())->declRef; - auto mangledName = module->session->getNameObj(getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef)); - witnessTables.TryGetValue(mangledName, witnessTable); - - if (!witnessTable) - { - // try specialize the witness table - auto genDeclRef = srcDeclRef; - genDeclRef.substitutions = createDefaultSubstitutions(module->session, genDeclRef.decl); - auto genName = module->session->getNameObj(getMangledNameForConformanceWitness(genDeclRef, interfaceDeclRef)); - IRWitnessTable* genTable = nullptr; - if (witnessTables.TryGetValue(genName, genTable)) - { - witnessTable = specializeWitnessTable(sharedContext, nullptr, genTable, srcDeclRef, nullptr); - witnessTables.AddIfNotExists(witnessTable->mangledName, witnessTable); - } - } - if (witnessTable) - { - lookupInst->replaceUsesWith(witnessTable); - lookupInst->removeAndDeallocate(); - } - } - break; - case kIROp_lookup_interface_method: - { - // We have a `lookup_interface_method` instruction, - // so let's see whether it is a lookup in a known - // witness table. - IRLookupWitnessMethod* lookupInst = (IRLookupWitnessMethod*) ii; - - // We only want to deal with the case where the witness-table - // argument points to a concrete global table. - auto witnessTableArg = lookupInst->witnessTable.get(); - if(witnessTableArg->op != kIROp_witness_table) - continue; - IRWitnessTable* witnessTable = (IRWitnessTable*)witnessTableArg; - - // We also need to be sure that the requirement we - // are trying to look up is identified via a decl-ref: - auto requirementArg = lookupInst->requirementDeclRef.get(); - if(requirementArg->op != kIROp_decl_ref) - continue; - auto requirementDeclRef = ((IRDeclRef*) requirementArg)->declRef; - - // Use the witness table to look up the value that - // satisfies the requirement. - auto satisfyingVal = findWitnessVal(witnessTable, requirementDeclRef); - // We expect to always find something, but lets just - // be careful here. - if(!satisfyingVal) - continue; - - // If we get through all of the above checks, then we - // have a (more) concrete method that implements the interface, - // and so we should dispatch to that directly, rather than - // use the `lookup_interface_method` instruction. - lookupInst->replaceUsesWith(satisfyingVal); - lookupInst->removeAndDeallocate(); - } - break; - } + next = inst->getNextInst(); + switch(inst->op) + { + default: + break; - // We only care about `specialize` instructions. - if(ii->op != kIROp_specialize) - continue; - + case kIROp_GlobalGenericParam: + case kIROp_BindGlobalGenericParam: + // A "bind" instruction should have no uses in the + // first place, and all the global generic parameters + // should have had their uses replaced. + SLANG_ASSERT(!inst->firstUse); + inst->removeAndDeallocate(); + break; } } } - // Once the work list has gone dry, we should have the invariant - // that there are no `specialize` instructions inside of non-generic - // functions that in turn reference a generic function. - } - - RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution( - EntryPointRequest * entryPointRequest, - ProgramLayout * programLayout, - IRSpecContext* context) - { - RefPtr<GlobalGenericParamSubstitution> globalParamSubst; - GlobalGenericParamSubstitution * curTailSubst = nullptr; - - // Because we can't currently put `specialize` instructions inside - // witness tables, or at the global scope, we will track a set of - // witness tables that we need to clone, and then specialize - // from the original module(s) to get what we need. + // Our goal here is to find `specialize` instructions that + // can be replaced with references to, e.g., a suitably + // specialized function, and to resolve any `lookup_interface_method` + // instructions to the concrete value fetched from a witness + // table. + // + // We need to be careful of a few things: + // + // * It would not in general make sense to consider specialize-able + // instructions under an `IRGeneric`, since that could mean "specialziing" + // code to parameter values that are still unknown. + // + // * We *also* need to be careful not to specialize something when one + // or more of its inputs is also a `specialize` or `lookup_interface_method` + // instruction, because then we'd be propagating through non-concrete + // values. + // + // The approach we use here is to build a work list of instructions + // that *can* become fully specialized, but aren't yet. Any + // instruction on the work list will be considered to be "unspecialized" + // and any instruction not on the work list is considered specialized. + // + // We will start by recursively walking all the instructions to add + // the appropriate ones to our work list: + // + addToSpecializationWorkListRec(sharedContext, moduleInst); - struct WitnessTableCloneWorkItem + // Now we are going to repeatedly walk our work list, and filter + // it to create a new work list. + List<IRInst*> workListCopy; + for(;;) { - IRWitnessTable* dstTable; - IRWitnessTable* originalTable; - }; - List<WitnessTableCloneWorkItem> witnessTablesToClone; + // Swap out the work list on the context so we can + // process it here without worrying about concurrent + // modifications. + workListCopy.Clear(); + workListCopy.SwapWith(sharedContext->workList); - struct WitnessTableSpecializationWorkItem - { - IRWitnessTable* dstTable; - IRWitnessTable* srcTable; - DeclRef<Decl> specDeclRef; - }; - List<WitnessTableSpecializationWorkItem> witnessTablesToSpecailize; - - Dictionary<Name*, IRWitnessTable*> witnessTablesByName; - auto namePool = entryPointRequest->compileRequest->getNamePool(); - - for (auto param : programLayout->globalGenericParams) - { - auto paramSubst = new GlobalGenericParamSubstitution(); - if (!globalParamSubst) - globalParamSubst = paramSubst; - if (curTailSubst) - curTailSubst->outer = paramSubst; - curTailSubst = paramSubst; - paramSubst->paramDecl = param->decl; - SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count()); - paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index]; - // find witness tables - for (auto witness : entryPointRequest->genericParameterWitnesses) + if(workListCopy.Count() == 0) + break; + + for(auto inst : workListCopy) { - if (auto subtypeWitness = witness.As<SubtypeWitness>()) + // We need to check whether it is possible to specialize + // the instruction yet (it might not be because its + // operands haven't been specialized) + if(!canSpecializeInst(sharedContext, inst)) { - if (subtypeWitness->sub->EqualsVal(paramSubst->actualType)) - { - auto witnessTableName = namePool->getName(getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup)); - auto findWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - RefPtr<IRSpecSymbol> symbol; - if (!context->getSymbols().TryGetValue(name, symbol)) - return nullptr; - - return (IRWitnessTable*) symbol->irGlobalValue; - }; - - auto findCloneOfWitnessTableByName = [&](Name* name) -> IRWitnessTable* - { - IRWitnessTable* clonedTable = nullptr; - if (witnessTablesByName.TryGetValue(name, clonedTable)) - return clonedTable; - - IRWitnessTable* originalTable = findWitnessTableByName(name); - if (!originalTable) - return nullptr; - - clonedTable = context->builder->createWitnessTable(); - - WitnessTableCloneWorkItem cloneWorkItem; - cloneWorkItem.originalTable = originalTable; - cloneWorkItem.dstTable = clonedTable; - witnessTablesToClone.Add(cloneWorkItem); - - return clonedTable; - }; - - // First look for a non-generic witness table that matches - auto table = findCloneOfWitnessTableByName(witnessTableName); - if (!table) - { - // If we didn't find a non-generic table, then maybe we are looking at - // a specialization of a generic witness table. - if (auto subDeclRefType = subtypeWitness->sub.As<DeclRefType>()) - { - auto defaultSubst = createDefaultSubstitutions(entryPointRequest->compileRequest->mSession, subDeclRefType->declRef.getDecl()); - auto genericWitnessTableName = namePool->getName( - getMangledNameForConformanceWitness(DeclRef<Decl>(subDeclRefType->declRef.getDecl(), defaultSubst), subtypeWitness->sup)); - - IRWitnessTable* genericTable = findCloneOfWitnessTableByName(genericWitnessTableName); - SLANG_ASSERT(genericTable); - - WitnessTableSpecializationWorkItem specializeWorkItem; - specializeWorkItem.srcTable = genericTable; - specializeWorkItem.dstTable = context->builder->createWitnessTable(); - specializeWorkItem.dstTable->mangledName = context->getModule()->session->getNameObj(getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup)); - specializeWorkItem.specDeclRef = subDeclRefType->declRef; - - witnessTablesToSpecailize.Add(specializeWorkItem); - table = specializeWorkItem.dstTable; - } - } - // We expect to find the table no matter what. - SLANG_ASSERT(table); + // Put it back on the fresh work list, so that + // we can re-consider it in another iteration. + sharedContext->workList.Add(inst); + } + else + { + // Okay, perform any specialization step on this + // instruction that makes sense (which might be + // doing nothing). + specializeGenericsForInst(sharedContext, inst); - IRProxyVal * tableVal = new IRProxyVal(); - tableVal->inst.init(nullptr, table); - paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); - } + // Remove the instruction from consideration. + sharedContext->workListSet.Remove(inst); } } } - for (auto workItem : witnessTablesToClone) - { - cloneWitnessTableWithoutRegistering( - context, - workItem.originalTable, - workItem.dstTable); - } - - for (auto workItem : witnessTablesToSpecailize) - { - int diff = 0; - specializeWitnessTable( - context->shared, - context, - workItem.srcTable, - workItem.specDeclRef.SubstituteImpl(SubstitutionSet(nullptr, nullptr, globalParamSubst), &diff), - workItem.dstTable); - } + // Once the work list has gone dry, we should have the invariant + // that there are no `specialize` instructions inside of non-generic + // functions that in turn reference a generic function, *except* + // in the case where that generic is for a builtin function, in + // which case we wouldn't want to specialize it anyway. + } - return globalParamSubst; + void applyGlobalGenericParamSubstitution( + IRSpecContext* /*context*/) + { + // TODO: we need to figure out how to apply this } - + void markConstExpr( - Session* session, - IRInst* irValue) + IRBuilder* builder, + IRInst* irValue) { // We will take an IR value with type `T`, // and turn it into one with type `@ConstExpr T`. @@ -6418,6 +6317,9 @@ namespace Slang // TODO: need to be careful if the value already has a rate // qualifier set. - irValue->type = session->getConstExprType(irValue->getDataType()); + irValue->setFullType( + builder->getRateQualifiedType( + builder->getConstExprRate(), + irValue->getDataType())); } } |
