diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 60 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-inversion.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 255 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 71 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 30 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 168 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 2 |
15 files changed, 640 insertions, 46 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 695423285..6c51ccef0 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -870,6 +870,14 @@ bool operator!=(__none_t noneVal, Optional<T> val) return val.hasValue; } +__generic<each T> +__magic_type(TupleType) +struct Tuple +{ + __intrinsic_op($(0)) + __init(expand each T); +} + __generic<T> __magic_type(NativeRefType) __intrinsic_type($(kIROp_NativePtrType)) diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index faf15470f..3a2b2933d 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -525,7 +525,19 @@ FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Typ TupleType* ASTBuilder::getTupleType(List<Type*>& types) { - return getOrCreate<TupleType>(types.getArrayView()); + // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))). + // If `types` is already a single ConcreteTypePack, then we can use that directly. + if (types.getCount() == 1) + { + if (isTypePack(types[0])) + { + return as<TupleType>(getSpecializedBuiltinType(types[0], "TupleType")); + } + } + + // Otherwise, we need to create a ConcreteTypePack to hold the types. + auto typePack = getTypePack(types.getArrayView()); + return as<TupleType>(getSpecializedBuiltinType(typePack, "TupleType")); } TypeType* ASTBuilder::getTypeType(Type* type) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 39de083f0..a4225d041 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -101,6 +101,7 @@ INST(Nop, nop, 0, 0) // /* Kind */ INST(TypeKind, Type, 0, HOISTABLE) + INST(TypeParameterPackKind, TypeParameterPack, 0, HOISTABLE) INST(RateKind, Rate, 0, HOISTABLE) INST(GenericKind, Generic, 0, HOISTABLE) INST_RANGE(Kind, TypeKind, GenericKind) @@ -244,6 +245,7 @@ INST(RTTIType, rtti_type, 0, HOISTABLE) INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) INST(TupleType, tuple_type, 0, HOISTABLE) INST(TargetTupleType, TargetTuple, 0, HOISTABLE) +INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) @@ -343,6 +345,9 @@ INST(MakeTuple, makeTuple, 0, 0) INST(MakeTargetTuple, makeTuple, 0, 0) INST(GetTargetTupleElement, getTargetTupleElement, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) +INST(MakeWitnessPack, MakeWitnessPack, 0, HOISTABLE) +INST(Expand, Expand, 1, 0) +INST(Each, Each, 1, HOISTABLE) INST(MakeResultValue, makeResultValue, 1, 0) INST(MakeResultError, makeResultError, 1, 0) INST(IsResultError, isResultError, 1, 0) @@ -566,6 +571,7 @@ INST(SwizzledStore, swizzledStore, 2, 0) /* IRTerminatorInst */ INST(Return, return_val, 1, 0) + INST(Yield, yield, 1, 0) /* IRUnconditionalBranch */ // unconditionalBranch <target> INST(unconditionalBranch, unconditionalBranch, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e30b903b5..dc5fb2744 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2484,6 +2484,13 @@ struct IRReturn : IRTerminatorInst IRInst* getVal() { return getOperand(0); } }; +struct IRYield : IRTerminatorInst +{ + IR_LEAF_ISA(Yield); + + IRInst* getVal() { return getOperand(0); } +}; + struct IRDiscard : IRTerminatorInst {}; @@ -2825,12 +2832,36 @@ struct IRBindGlobalGenericParam : IRInst IR_LEAF_ISA(BindGlobalGenericParam) }; +struct IRExpand : IRInst +{ + IR_LEAF_ISA(Expand) + UInt getCaptureCount() { return getOperandCount(); } + IRInst* getCapture(UInt index) { return getOperand(index); } + IRInstList<IRBlock> getBlocks() + { + return IRInstList<IRBlock>(getChildren()); + } +}; + + +struct IREach : IRInst +{ + IR_LEAF_ISA(Each) + + IRInst* getElement() { return getOperand(0); } +}; + // An Instruction that creates a tuple value. struct IRMakeTuple : IRInst { IR_LEAF_ISA(MakeTuple) }; +struct IRMakeWitnessPack : IRInst +{ + IR_LEAF_ISA(MakeWitnessPack) +}; + struct IRGetTupleElement : IRInst { IR_LEAF_ISA(GetTupleElement) @@ -3328,7 +3359,7 @@ public: // Get the current function (or other value with code) // that we are inserting into (if any). - IRGlobalValueWithCode* getFunc() { return m_insertLoc.getFunc(); } + IRInst* getFunc() { return m_insertLoc.getFunc(); } void setInsertInto(IRInst* insertInto) { setInsertLoc(IRInsertLoc::atEnd(insertInto)); } void setInsertBefore(IRInst* insertBefore) { setInsertLoc(IRInsertLoc::before(insertBefore)); } @@ -3478,6 +3509,8 @@ public: IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2); IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3); + IRExpandType* getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture); + IRResultType* getResultType(IRType* valueType, IRType* errorType); IROptionalType* getOptionalType(IRType* valueType); @@ -3485,6 +3518,7 @@ public: IRWitnessTableType* getWitnessTableType(IRType* baseType); IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType); IRType* getTypeType() { return getType(IROp::kIROp_TypeType); } + IRType* getTypeParameterPackKind() { return getType(IROp::kIROp_TypeParameterPackKind); } IRType* getKeyType() { return nullptr; } IRTypeKind* getTypeKind(); @@ -3715,6 +3749,9 @@ public: return emitSpecializeInst(type, genericVal, args.getCount(), args.begin()); } + IRInst* emitExpandInst(IRType* type, UInt capturedArgCount, IRInst* const* capturedArgs); + IRInst* emitEachInst(IRType* type, IRInst* base, IRInst* indexArg = nullptr); + IRInst* emitLookupInterfaceMethodInst( IRType* type, IRInst* witnessTableVal, @@ -3814,6 +3851,13 @@ public: IRInst* emitMakeTuple(IRType* type, List<IRInst*> const& args) { + if (args.getCount() == 1) + { + if (args[0]->getOp() == kIROp_Expand) + { + return args[0]; + } + } return emitMakeTuple(type, args.getCount(), args.getBuffer()); } @@ -3828,11 +3872,22 @@ public: return emitMakeTuple(SLANG_COUNT_OF(args), args); } + IRInst* emitMakeWitnessPack(IRType* type, ArrayView<IRInst*> args) + { + return emitIntrinsicInst(type, kIROp_MakeWitnessPack, (UInt)args.getCount(), args.getBuffer()); + } + IRInst* emitMakeString(IRInst* nativeStr); IRInst* emitGetNativeString(IRInst* str); + IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, int element) + { + return emitGetTupleElement(type, tuple, (UInt)element); + } + IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); + IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element); IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal); IRInst* emitMakeResultValue(IRType* resultType, IRInst* val); @@ -4186,6 +4241,9 @@ public: IRInst* emitReturn( IRInst* val); + IRInst* emitYield( + IRInst* val); + IRInst* emitReturn(); IRInst* emitThrow(IRInst* val); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 5bb485b22..8b08b9045 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1170,6 +1170,17 @@ IRFunc* cloneFuncImpl( return clonedFunc; } +// Can an inst with `opcode` contain basic blocks as children? +bool canInstContainBasicBlocks(IROp opcode) +{ + switch (opcode) + { + case kIROp_Expand: + return true; + default: + return false; + } +} IRInst* cloneInst( IRSpecContextBase* context, @@ -1238,7 +1249,10 @@ IRInst* cloneInst( argCount, newArgs.getArrayView().getBuffer()); builder->addInst(clonedInst); registerClonedValue(context, clonedInst, originalValues); - cloneDecorationsAndChildren(context, clonedInst, originalInst); + if (canInstContainBasicBlocks(clonedInst->getOp())) + cloneGlobalValueWithCodeCommon(context, (IRGlobalValueWithCode*)clonedInst, (IRGlobalValueWithCode*)originalInst, originalValues); + else + cloneDecorationsAndChildren(context, clonedInst, originalInst); cloneExtraDecorations(context, clonedInst, originalValues); return clonedInst; } diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index c811bf357..9ae734507 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -140,7 +140,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) builder.setInsertInto(loop->getParent()); const auto s = as<IRBlock>(loop->getParent()); - auto domTree = computeDominatorTree(s->getParent()); + auto domTree = computeDominatorTree((IRGlobalValueWithCode*)s->getParent()); SLANG_ASSERT(s); const auto c1 = loop->getTargetBlock(); const auto c1Terminator = as<IRIfElse>(c1->getTerminator()); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 232633d69..aa8dfddab 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -324,7 +324,10 @@ struct PeepholeContext : InstPassBase } break; case kIROp_GetTupleElement: - if (inst->getOperand(0)->getOp() == kIROp_MakeTuple) + switch (inst->getOperand(0)->getOp()) + { + case kIROp_MakeTuple: + case kIROp_MakeWitnessPack: { auto element = inst->getOperand(1); if (auto intLit = as<IRIntLit>(element)) @@ -333,6 +336,10 @@ struct PeepholeContext : InstPassBase maybeRemoveOldInst(inst); changed = true; } + break; + } + default: + break; } break; case kIROp_FieldExtract: @@ -1181,6 +1188,15 @@ bool peepholeOptimize(TargetProgram* target, IRInst* func) return context.processFunc(func); } +bool peepholeOptimizeInst(TargetProgram* target, IRModule* module, IRInst* inst) +{ + PeepholeContext context = PeepholeContext(module); + context.targetProgram = target; + context.useFastAnalysis = true; + context.processInst(inst); + return context.changed; +} + bool peepholeOptimizeGlobalScope(TargetProgram* target, IRModule* module) { PeepholeContext context = PeepholeContext(module); diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h index 411267072..3fdb74450 100644 --- a/source/slang/slang-ir-peephole.h +++ b/source/slang/slang-ir-peephole.h @@ -22,6 +22,7 @@ namespace Slang /// Apply peephole optimizations. bool peepholeOptimize(TargetProgram* target, IRModule* module, PeepholeOptimizationOptions options); bool peepholeOptimize(TargetProgram* target, IRInst* func); + bool peepholeOptimizeInst(TargetProgram* target, IRModule* module, IRInst* inst); bool peepholeOptimizeGlobalScope(TargetProgram* target, IRModule* module); bool tryReplaceInstUsesWithSimplifiedValue(TargetProgram* target, IRModule* module, IRInst* inst); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index c86906b2d..2eb16112f 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1,6 +1,6 @@ // slang-ir-specialize.cpp #include "slang-ir-specialize.h" - +#include "slang-ir-peephole.h" #include "slang-ir.h" #include "slang-ir-clone.h" #include "slang-ir-insts.h" @@ -585,6 +585,15 @@ struct SpecializationContext case kIROp_BindExistentialsType: return maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst)); + + case kIROp_Expand: + return maybeSpecializeExpand(as<IRExpand>(inst)); + + case kIROp_ExpandTypeOrVal: + return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst)); + + case kIROp_GetTupleElement: + return maybeSpecializeFoldableInst(inst); } } @@ -597,7 +606,7 @@ struct SpecializationContext { // Note: While we currently have named the instruction // `lookup_witness_method`, the `method` part is a misnomer - // and the same instruction can look up *any* interface + // and the same instruction can look up *any* interfacemay // requirement based on the witness table that provides // a conformance, and the "key" that indicates the interface // requirement. @@ -609,7 +618,9 @@ struct SpecializationContext // auto witnessTable = as<IRWitnessTable>(lookupInst->getWitnessTable()); if (!witnessTable) + { return false; + } // Because we have a concrete witness table, we can // use it to look up the IR value that satisfies @@ -642,6 +653,19 @@ struct SpecializationContext return true; } + bool maybeSpecializeFoldableInst(IRInst* inst) + { + auto firstUse = inst->firstUse; + bool instChanged = peepholeOptimizeInst(targetProgram, module, inst); + + for (auto use = firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + addToWorkList(user); + } + return instChanged; + } + // The above subroutine needed a way to look up // the satisfying value for a given requirement // key in a concrete witness table, so let's @@ -2208,6 +2232,233 @@ struct SpecializationContext return false; } + IRInst* specializeExpandChildInst(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* childInst, UInt index) + { + IRCloneEnv freshEnv; + IRCloneEnv* subEnv = &cloneEnv; + switch (childInst->getOp()) + { + case kIROp_Expand: + { + subEnv = &freshEnv; + break; + } + } + auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index); + for (UInt i = 0; i < childInst->getOperandCount(); i++) + { + clonePatternVal(*subEnv, builder, childInst->getOperand(i), index); + } + auto newInst = cloneInst(subEnv, builder, childInst); + newInst = builder->replaceOperand(&newInst->typeUse, type); + subEnv->mapOldValToNew[childInst] = newInst; + IRBuilder subBuilder(*builder); + subBuilder.setInsertInto(newInst); + for (auto child : childInst->getChildren()) + { + specializeExpandChildInst(*subEnv, &subBuilder, child, index); + } + return newInst; + } + + bool maybeSpecializeExpand(IRExpand* expandInst) + { + if (expandInst->getCaptureCount() == 0) + return false; + + for (UInt i = 0; i < expandInst->getCaptureCount(); i++) + { + if (!as<IRTupleType>(expandInst->getCapture(i))) + return false; + } + + IRBuilder builder(expandInst); + builder.setInsertBefore(expandInst); + List<IRInst*> elements; + UInt elementCount = 0; + if (auto firstTupleType = as<IRTupleType>(expandInst->getCapture(0))) + { + elementCount = firstTupleType->getOperandCount(); + } + if (elementCount == 0) + { + auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr); + expandInst->replaceUsesWith(resultTuple); + expandInst->removeAndDeallocate(); + addUsersToWorkList(resultTuple); + return true; + } + + for (UInt i = 0; i < elementCount; i++) + { + IRCloneEnv cloneEnv; + IRBlock* firstBlock = nullptr; + IRBuilder subBuilder = builder; + for (auto childBlock : expandInst->getBlocks()) + { + auto newBlock = subBuilder.emitBlock(); + if (!firstBlock) + firstBlock = newBlock; + cloneEnv.mapOldValToNew[childBlock] = newBlock; + } + auto indexParam = expandInst->getFirstBlock()->getFirstParam(); + SLANG_ASSERT(indexParam); + cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i); + + builder.emitBranch(firstBlock); + + IRBlock* mergeBlock = subBuilder.emitBlock(); + builder.setInsertInto(mergeBlock); + + for (auto childBlock : expandInst->getBlocks()) + { + auto newBlock = cloneEnv.mapOldValToNew[childBlock]; + subBuilder.setInsertInto(newBlock); + for (auto child : childBlock->getChildren()) + { + if (as<IRYield>(child)) + { + elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]); + subBuilder.emitBranch(mergeBlock); + continue; + } + specializeExpandChildInst(cloneEnv, &subBuilder, child, i); + addToWorkList(childBlock); + } + } + + } + auto resultTuple = builder.emitMakeTuple(elements); + auto currentBlock = builder.getBlock(); + for (auto nextInst = expandInst->next; nextInst;) + { + auto next = nextInst->next; + nextInst->insertAtEnd(currentBlock); + nextInst = next; + } + addUsersToWorkList(expandInst); + expandInst->replaceUsesWith(resultTuple); + expandInst->removeAndDeallocate(); + return true; + } + + IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) + { + if (!val) + return val; + + switch (val->getOp()) + { + case kIROp_ExpandTypeOrVal: + return val; + case kIROp_Each: + { + auto eachInst = as<IREach>(val); + auto packInst = eachInst->getElement(); + if (auto tuple = as<IRTupleType>(packInst)) + { + SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount()); + return tuple->getOperand(indexInPack); + } + else if (auto makeTuple = as<IRMakeTuple>(packInst)) + { + SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount()); + return makeTuple->getOperand(indexInPack); + } + else if (!as<IRTypeKind>(packInst->getDataType())) + { + auto type = clonePatternVal(cloneEnv, builder, val, indexInPack); + return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack); + } + return val; + } + default: + break; + } + bool anyChange = false; + ShortList<IRInst*> operands; + for (UInt i = 0; i < val->getOperandCount(); i++) + { + auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack); + if (newOperand != val->getOperand(i)) + anyChange = true; + operands.add(newOperand); + } + auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack); + if (newType != val->getFullType()) + anyChange = true; + if (!anyChange) + return val; + + auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer()); + if (newVal != val) + { + cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal); + } + return newVal; + } + + IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) + { + if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val)) + return *clonedVal; + cloneEnv.mapOldValToNew[val] = val; + auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack); + cloneEnv.mapOldValToNew[val] = result; + return result; + } + + bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst) + { + if (expandInst->getCaptureCount() == 0) + return false; + + bool anyAbstractPack = false; + for (UInt i = 0; i < expandInst->getCaptureCount(); i++) + { + if (!as<IRTupleType>(expandInst->getCaptureType(i))) + { + anyAbstractPack = true; + break; + } + } + if (anyAbstractPack) + return false; + IRBuilder builder(expandInst); + builder.setInsertBefore(expandInst); + List<IRInst*> elements; + UInt elementCount = 0; + if (auto firstTupleType = as<IRTupleType>(expandInst->getCaptureType(0))) + { + elementCount = firstTupleType->getOperandCount(); + } + for (UInt i = 0; i < elementCount; i++) + { + IRCloneEnv cloneEnv; + auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i); + elements.add(element); + } + addUsersToWorkList(expandInst); + if (as<IRWitnessTableType>(expandInst->getDataType())) + { + List<IRType*> types; + for (auto element : elements) + types.add(element->getDataType()); + auto newTupleType = builder.getTupleType(types); + auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView()); + expandInst->replaceUsesWith(result); + expandInst->removeAndDeallocate(); + return true; + } + else + { + auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer()); + expandInst->replaceUsesWith(newTupleType); + expandInst->removeAndDeallocate(); + return true; + } + } + // The handling of specialization for global generic type // parameters involves searching for all `bind_global_generic_param` // instructions in the input module. diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 788c9a391..e44c4079b 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -146,7 +146,8 @@ bool allUsesLeadToLoads(IRInst* inst) // Is the given variable one that we can promote to SSA form? bool isPromotableVar( ConstructSSAContext* /*context*/, - IRVar* var) + IRVar* var, + HashSet<IRBlock*> &knownBlocks) { // We want to identify variables such that we can always // determine what they will contain at a point in the @@ -226,8 +227,13 @@ bool isPromotableVar( } break; } + + // If the use is outside of known blocks, then we can't promote it. + if (!knownBlocks.contains(getBlock(user))) + return false; } + // If all of the uses passed our checking, then // we are good to go. return true; @@ -237,6 +243,12 @@ bool isPromotableVar( void identifyPromotableVars( ConstructSSAContext* context) { + HashSet<IRBlock*> knownBlocks; + for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + knownBlocks.add(bb); + } + for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock()) { for (auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst()) @@ -246,7 +258,7 @@ void identifyPromotableVars( IRVar* var = (IRVar*)ii; - if (isPromotableVar(context, var)) + if (isPromotableVar(context, var, knownBlocks)) { context->promotableVars.add(var); } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index bfd6c20cf..c97c04f88 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -808,12 +808,7 @@ namespace Slang IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } - - void IRGlobalValueWithCode::addBlock(IRBlock* block) - { - block->insertAtEnd(this); - } - + void fixUpFuncType(IRFunc* func, IRType* resultType) { SLANG_ASSERT(func); @@ -1279,14 +1274,16 @@ namespace Slang // Get the current function (or other value with code) // that we are inserting into (if any). - IRGlobalValueWithCode* IRInsertLoc::getFunc() const + IRInst* IRInsertLoc::getFunc() const { auto pp = getParent(); if (const auto block = as<IRBlock>(pp)) { pp = pp->getParent(); } - return as<IRGlobalValueWithCode>(pp); + if (as<IRGlobalValueWithCode>(pp) || as<IRExpand>(pp)) + return pp; + return nullptr; } void addHoistableInst( @@ -2805,6 +2802,14 @@ namespace Slang return getTupleType(SLANG_COUNT_OF(operands), operands); } + IRExpandType* IRBuilder::getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture) + { + ShortList<IRInst*> args; + args.add(pattern); + args.addRange(capture); + return (IRExpandType*)emitIntrinsicInst(type, kIROp_ExpandTypeOrVal, args.getCount(), args.getArrayView().getBuffer()); + } + IRResultType* IRBuilder::getResultType(IRType* valueType, IRType* errorType) { IRInst* operands[] = {valueType, errorType}; @@ -3548,6 +3553,26 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitExpandInst(IRType* type, UInt capturedArgCount, IRInst* const* capturedArgs) + { + auto inst = createInstWithTrailingArgs<IRSpecialize>( + this, + kIROp_Expand, + type, + capturedArgCount, + capturedArgs, + 0, + nullptr); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitEachInst(IRType* type, IRInst* base, IRInst* indexArg) + { + IRInst* args[] = { base, indexArg }; + return emitIntrinsicInst(type, kIROp_Each, indexArg ? 2 : 1, args); + } + IRInst* IRBuilder::emitLookupInterfaceMethodInst( IRType* type, IRInst* witnessTableVal, @@ -4057,6 +4082,12 @@ namespace Slang return emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str); } + IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element) + { + IRInst* args[] = { tuple, element }; + return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args); + } + IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element) { // As a quick simplification/optimization, if the user requests @@ -4070,9 +4101,7 @@ namespace Slang return makeTuple->getOperand(element); } } - - IRInst* args[] = { tuple, getIntValue(getIntType(), element) }; - return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args); + return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element)); } IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal) @@ -5409,6 +5438,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitYield( + IRInst* val) + { + auto inst = createInst<IRYield>( + this, + kIROp_Yield, + nullptr, + val); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitReturn() { auto voidVal = getVoidValue(); @@ -7238,6 +7279,7 @@ namespace Slang case kIROp_Func: case kIROp_GlobalVar: case kIROp_Generic: + case kIROp_Expand: dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); return; @@ -8230,6 +8272,7 @@ namespace Slang case kIROp_WitnessTableEntry: case kIROp_InterfaceRequirementEntry: case kIROp_Block: + case kIROp_Each: return false; /// Liveness markers have no side effects @@ -8250,6 +8293,7 @@ namespace Slang case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: case kIROp_VectorReshape: + case kIROp_MakeWitnessPack: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: case kIROp_MakeStruct: @@ -8806,6 +8850,11 @@ namespace Slang } } + void IRInst::addBlock(IRBlock* block) + { + block->insertAtEnd(this); + } + void IRInst::dump() { if (auto intLit = as<IRIntLit>(this)) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 719b383c3..ececdad43 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -534,7 +534,7 @@ public: /// This searches up the parent chain starting with `getParent()` looking for a code-bearing /// value that things are being inserted into (could be a function, generic, etc.) /// - IRGlobalValueWithCode* getFunc() const; + IRInst* getFunc() const; private: /// Internal constructor @@ -567,6 +567,8 @@ enum class IRTypeLayoutRuleName _Count, }; +struct IRBlock; + // Every value in the IR is an instruction (even things // like literal values). // @@ -833,6 +835,13 @@ struct IRInst /// Print the IR to stdout for debugging purposes /// void dump(); + + /// Insert a basic block at the end of this func/code containing inst. + void addBlock(IRBlock* block); + + IRBlock* getFirstBlock() { return (IRBlock*)getFirstChild(); } + IRBlock* getLastBlock() { return (IRBlock*)getLastChild(); } + }; enum class IRDynamicCastBehavior @@ -1291,11 +1300,6 @@ struct IRBlock : IRInst getLastOrdinaryInst()); } - // The parent of a basic block is assumed to be a - // value with code (e.g., a function, global variable - // with initializer, etc.). - IRGlobalValueWithCode* getParent() { return cast<IRGlobalValueWithCode>(IRInst::getParent()); } - // The predecessor and successor lists of a block are needed // when we want to work with the control flow graph (CFG) of // a function. Rather than store these explicitly (and thus @@ -1620,6 +1624,7 @@ struct IRRateQualifiedType : IRType // same type. SIMPLE_IR_PARENT_TYPE(Kind, Type); SIMPLE_IR_TYPE(TypeKind, Kind); +SIMPLE_IR_TYPE(TypeParameterPackKind, Kind); // The kind of any and all generics. // @@ -1941,6 +1946,16 @@ struct IRTargetTupleType : IRType IR_LEAF_ISA(TargetTupleType) }; +/// Represents a `expand T` type used in variadic generic decls in Slang. Expected to be substituted +/// by actual types during specialization. +struct IRExpandType : IRType +{ + IR_LEAF_ISA(ExpandTypeOrVal) + IRType* getPatternType() { return (IRType*)(getOperand(0)); } + UInt getCaptureCount() { return getOperandCount() - 1; } + IRType* getCaptureType(UInt index) { return (IRType*)(getOperand(index + 1)); } +}; + /// Represents an `Result<T,E>`, used by functions that throws error codes. struct IRResultType : IRType { @@ -2040,9 +2055,6 @@ struct IRGlobalValueWithCode : IRInst return IRInstList<IRBlock>(getChildren()); } - // Add a block to the end of this function. - void addBlock(IRBlock* block); - IR_PARENT_ISA(GlobalValueWithCode) }; diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 5e8390cef..6009ef33e 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -287,6 +287,11 @@ struct TupleTypeBuilder { specialType = legalFieldType; } + + // `void` is currently legalized to simple, but we don't want to add a + // `void` field to the struct. + if (legalLeafType.getSimple()->getOp() == kIROp_VoidType) + return; } break; @@ -419,7 +424,6 @@ struct TupleTypeBuilder bool isSpecialField = context->isSpecialType(fieldType); auto legalFieldType = legalizeType(context, fieldType); - addField( field->getKey(), legalFieldType, @@ -1385,10 +1389,15 @@ LegalType legalizeTypeImpl( context, arrayType->getElementType()); - // If element type hasn't change, return original type. - if (legalElementType.flavor == LegalType::Flavor::simple && - legalElementType.getSimple() == arrayType->getElementType()) - return LegalType::simple(arrayType); + if (legalElementType.flavor == LegalType::Flavor::simple) + { + if (legalElementType.getSimple()->getOp() == kIROp_VoidType) + return LegalType(); + + // If element type hasn't change, return original type. + if (legalElementType.getSimple() == arrayType->getElementType()) + return LegalType::simple(arrayType); + } ArrayLegalTypeWrapper wrapper; wrapper.arrayType = arrayType; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 95e9d96da..9ceb3074a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -594,6 +594,9 @@ struct IRGenContext bool includeDebugInfo = false; + // The element index if we are inside an `expand` expression. + IRInst* expandIndex = nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared) , astBuilder(inAstBuilder) @@ -1653,6 +1656,86 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(resultVal); } + LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack) + { + ShortList<IRType*> types; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto loweredType = lowerType(context, typePack->getElementType(i)); + types.add(loweredType); + } + auto irBuilder = getBuilder(); + IRType* irTypePack = irBuilder->getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer()); + return LoweredValInfo::simple(irTypePack); + } + + LoweredValInfo visitEachType(EachType* eachType) + { + auto type = lowerType(context, eachType->getElementType()); + return LoweredValInfo::simple(getBuilder()->emitEachInst( + getBuilder()->getTypeKind(), + type)); + } + + LoweredValInfo visitExpandType(ExpandType* expandType) + { + auto irBuilder = getBuilder(); + auto type = lowerType(context, expandType->getPatternType()); + ShortList<IRInst*> capturedTypes; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) + { + auto loweredType = lowerType(context, expandType->getCapturedTypePack(i)); + capturedTypes.add(loweredType); + } + return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal( + irBuilder->getTypeKind(), type, capturedTypes.getArrayView().arrayView)); + } + + LoweredValInfo visitTypePackSubtypeWitness(TypePackSubtypeWitness* witnessPack) + { + auto irBuilder = getBuilder(); + ShortList<IRInst*> witnesses; + ShortList<IRType*> elementTypes; + for (Index i = 0; i < witnessPack->getCount(); i++) + { + auto loweredWitness = lowerVal(context, witnessPack->getWitness(i)); + witnesses.add(loweredWitness.val); + elementTypes.add(loweredWitness.val->getFullType()); + } + auto irWitnessPack = irBuilder->emitMakeWitnessPack( + irBuilder->getTupleType((UInt)elementTypes.getCount(), elementTypes.getArrayView().getBuffer()), + witnesses.getArrayView().arrayView); + return LoweredValInfo::simple(irWitnessPack); + } + + LoweredValInfo visitExpandSubtypeWitness(ExpandSubtypeWitness* witness) + { + auto irBuilder = getBuilder(); + + auto patternWitnessVal = lowerVal(context, witness->getPatternTypeWitness()); + auto subType = lowerType(context, witness->getSub()); + auto supType = lowerType(context, witness->getSup()); + auto witnessTableType = irBuilder->getWitnessTableType(supType); + ShortList<IRInst*> captures; + if (auto expandType = as<IRExpandType>(subType)) + { + for (UInt i = 0; i < expandType->getCaptureCount(); i++) + { + captures.add(expandType->getCaptureType(i)); + } + } + return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(witnessTableType, patternWitnessVal.val, captures.getArrayView().arrayView)); + } + + LoweredValInfo visitEachSubtypeWitness(EachSubtypeWitness* witness) + { + auto elementWitness = lowerVal(context, witness->getPatternTypeWitness()); + auto irBuilder = getBuilder(); + auto subType = lowerType(context, witness->getSub()); + auto witnessTableType = irBuilder->getWitnessTableType(subType); + return LoweredValInfo::simple(irBuilder->emitEachInst(witnessTableType, getSimpleVal(context, elementWitness))); + } + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) { if (as<ThisTypeConstraintDecl>(val->getDeclRef())) @@ -1885,6 +1968,23 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower context->irBuilder->getTypeKind())); } + IRType* visitTupleType(TupleType* type) + { + List<IRType*> elementTypes; + if (as<ConcreteTypePack>(type->getTypePack())) + { + for (Index i = 0; i < type->getMemberCount(); i++) + { + elementTypes.add(lowerType(context, type->getMember(i))); + } + return context->irBuilder->getTupleType(elementTypes); + } + else + { + return lowerType(context, type->getTypePack()); + } + } + IRType* visitNamedExpressionType(NamedExpressionType* type) { return (IRType*)getSimpleVal(context, dispatchType(type->getCanonicalType())); @@ -4315,19 +4415,54 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return lowerSubExpr(expr->base); } - LoweredValInfo visitPackExpr(PackExpr*) + LoweredValInfo visitPackExpr(PackExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for pack expression"); + List<IRInst*> irArgs; + for (auto arg : expr->args) + { + irArgs.add(getSimpleVal(context, lowerSubExpr(arg))); + } + auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs); + return LoweredValInfo::simple(irMakeTuple); } - LoweredValInfo visitEachExpr(EachExpr*) + LoweredValInfo visitEachExpr(EachExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for each expression"); + auto subVal = lowerSubExpr(expr->baseExpr); + SLANG_ASSERT(context->expandIndex); + auto irEach = getBuilder()->emitGetTupleElement(lowerType(context, expr->type), getSimpleVal(context, subVal), context->expandIndex); + return LoweredValInfo::simple(irEach); } - LoweredValInfo visitExpandExpr(ExpandExpr*) + LoweredValInfo visitExpandExpr(ExpandExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for expand expression"); + auto irBuilder = getBuilder(); + auto irType = lowerType(context, expr->type); + List<IRInst*> irCapturedPacks; + if (auto expandType = as<IRExpandType>(irType)) + { + for (UInt i = 0; i < expandType->getCaptureCount(); i++) + { + irCapturedPacks.add(expandType->getCaptureType(i)); + } + } + else + { + // If the type of the expression is not an ExpandType, then it must be + // a DeclRefType to a generic type pack parameter. + // In this case, the captured type is just the DeclRefType itself. + irCapturedPacks.add(irType); + } + auto expandInst = irBuilder->emitExpandInst(irType, (UInt)irCapturedPacks.getCount(), irCapturedPacks.getBuffer()); + irBuilder->setInsertInto(expandInst); + irBuilder->emitBlock(); + auto eachIndex = irBuilder->emitParam(irBuilder->getIntType()); + IRInst* oldExpandIndex = context->expandIndex; + context->expandIndex = eachIndex; + SLANG_DEFER(context->expandIndex = oldExpandIndex); + irBuilder->emitYield(getSimpleVal(context, lowerSubExpr(expr->baseExpr))); + irBuilder->setInsertAfter(expandInst); + return LoweredValInfo::simple(expandInst); } LoweredValInfo getSimpleDefaultVal(IRType* type) @@ -8968,11 +9103,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // in the order they were declared. for (auto member : genericDecl->members) { - if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + if (auto typeParamDecl = as<GenericTypeParamDeclBase>(member)) { - // TODO: use a `TypeKind` to represent the - // classifier of the parameter. - auto param = subBuilder->emitParam(subBuilder->getTypeType()); + IRType* typeKind = nullptr; + if (as<GenericTypePackParamDecl>(member)) + typeKind = subBuilder->getTypeParameterPackKind(); + else + typeKind = subBuilder->getTypeType(); + auto param = subBuilder->emitParam(typeKind); addNameHint(context, param, typeParamDecl); subContext->setValue(typeParamDecl, LoweredValInfo::simple(param)); } @@ -10289,7 +10427,15 @@ LoweredValInfo ensureDecl( } IRBuilder subIRBuilder(context->irBuilder->getModule()); - subIRBuilder.setInsertInto(subIRBuilder.getModule()); + if (as<VarDecl>(decl) && decl->findModifier<LocalTempVarModifier>()) + { + // Do not modify insert location. + subIRBuilder.setInsertLoc(context->irBuilder->getInsertLoc()); + } + else + { + subIRBuilder.setInsertInto(subIRBuilder.getModule()); + } IRGenEnv subEnv; subEnv.outer = context->env; diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 7fa9e8fc0..a55f0eb1a 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -507,7 +507,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // TODO: need to figure out how to unify this with the logic // in the generic case... Type* DeclRefType::create( - ASTBuilder* astBuilder, + ASTBuilder* astBuilder, DeclRef<Decl> declRef) { if (declRef.getDecl()->findModifier<BuiltinTypeModifier>()) |
