diff options
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 168 |
1 files changed, 157 insertions, 11 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 95e9d96da..9ceb3074a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -594,6 +594,9 @@ struct IRGenContext bool includeDebugInfo = false; + // The element index if we are inside an `expand` expression. + IRInst* expandIndex = nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared) , astBuilder(inAstBuilder) @@ -1653,6 +1656,86 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(resultVal); } + LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack) + { + ShortList<IRType*> types; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto loweredType = lowerType(context, typePack->getElementType(i)); + types.add(loweredType); + } + auto irBuilder = getBuilder(); + IRType* irTypePack = irBuilder->getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer()); + return LoweredValInfo::simple(irTypePack); + } + + LoweredValInfo visitEachType(EachType* eachType) + { + auto type = lowerType(context, eachType->getElementType()); + return LoweredValInfo::simple(getBuilder()->emitEachInst( + getBuilder()->getTypeKind(), + type)); + } + + LoweredValInfo visitExpandType(ExpandType* expandType) + { + auto irBuilder = getBuilder(); + auto type = lowerType(context, expandType->getPatternType()); + ShortList<IRInst*> capturedTypes; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) + { + auto loweredType = lowerType(context, expandType->getCapturedTypePack(i)); + capturedTypes.add(loweredType); + } + return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal( + irBuilder->getTypeKind(), type, capturedTypes.getArrayView().arrayView)); + } + + LoweredValInfo visitTypePackSubtypeWitness(TypePackSubtypeWitness* witnessPack) + { + auto irBuilder = getBuilder(); + ShortList<IRInst*> witnesses; + ShortList<IRType*> elementTypes; + for (Index i = 0; i < witnessPack->getCount(); i++) + { + auto loweredWitness = lowerVal(context, witnessPack->getWitness(i)); + witnesses.add(loweredWitness.val); + elementTypes.add(loweredWitness.val->getFullType()); + } + auto irWitnessPack = irBuilder->emitMakeWitnessPack( + irBuilder->getTupleType((UInt)elementTypes.getCount(), elementTypes.getArrayView().getBuffer()), + witnesses.getArrayView().arrayView); + return LoweredValInfo::simple(irWitnessPack); + } + + LoweredValInfo visitExpandSubtypeWitness(ExpandSubtypeWitness* witness) + { + auto irBuilder = getBuilder(); + + auto patternWitnessVal = lowerVal(context, witness->getPatternTypeWitness()); + auto subType = lowerType(context, witness->getSub()); + auto supType = lowerType(context, witness->getSup()); + auto witnessTableType = irBuilder->getWitnessTableType(supType); + ShortList<IRInst*> captures; + if (auto expandType = as<IRExpandType>(subType)) + { + for (UInt i = 0; i < expandType->getCaptureCount(); i++) + { + captures.add(expandType->getCaptureType(i)); + } + } + return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(witnessTableType, patternWitnessVal.val, captures.getArrayView().arrayView)); + } + + LoweredValInfo visitEachSubtypeWitness(EachSubtypeWitness* witness) + { + auto elementWitness = lowerVal(context, witness->getPatternTypeWitness()); + auto irBuilder = getBuilder(); + auto subType = lowerType(context, witness->getSub()); + auto witnessTableType = irBuilder->getWitnessTableType(subType); + return LoweredValInfo::simple(irBuilder->emitEachInst(witnessTableType, getSimpleVal(context, elementWitness))); + } + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) { if (as<ThisTypeConstraintDecl>(val->getDeclRef())) @@ -1885,6 +1968,23 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower context->irBuilder->getTypeKind())); } + IRType* visitTupleType(TupleType* type) + { + List<IRType*> elementTypes; + if (as<ConcreteTypePack>(type->getTypePack())) + { + for (Index i = 0; i < type->getMemberCount(); i++) + { + elementTypes.add(lowerType(context, type->getMember(i))); + } + return context->irBuilder->getTupleType(elementTypes); + } + else + { + return lowerType(context, type->getTypePack()); + } + } + IRType* visitNamedExpressionType(NamedExpressionType* type) { return (IRType*)getSimpleVal(context, dispatchType(type->getCanonicalType())); @@ -4315,19 +4415,54 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return lowerSubExpr(expr->base); } - LoweredValInfo visitPackExpr(PackExpr*) + LoweredValInfo visitPackExpr(PackExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for pack expression"); + List<IRInst*> irArgs; + for (auto arg : expr->args) + { + irArgs.add(getSimpleVal(context, lowerSubExpr(arg))); + } + auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs); + return LoweredValInfo::simple(irMakeTuple); } - LoweredValInfo visitEachExpr(EachExpr*) + LoweredValInfo visitEachExpr(EachExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for each expression"); + auto subVal = lowerSubExpr(expr->baseExpr); + SLANG_ASSERT(context->expandIndex); + auto irEach = getBuilder()->emitGetTupleElement(lowerType(context, expr->type), getSimpleVal(context, subVal), context->expandIndex); + return LoweredValInfo::simple(irEach); } - LoweredValInfo visitExpandExpr(ExpandExpr*) + LoweredValInfo visitExpandExpr(ExpandExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for expand expression"); + auto irBuilder = getBuilder(); + auto irType = lowerType(context, expr->type); + List<IRInst*> irCapturedPacks; + if (auto expandType = as<IRExpandType>(irType)) + { + for (UInt i = 0; i < expandType->getCaptureCount(); i++) + { + irCapturedPacks.add(expandType->getCaptureType(i)); + } + } + else + { + // If the type of the expression is not an ExpandType, then it must be + // a DeclRefType to a generic type pack parameter. + // In this case, the captured type is just the DeclRefType itself. + irCapturedPacks.add(irType); + } + auto expandInst = irBuilder->emitExpandInst(irType, (UInt)irCapturedPacks.getCount(), irCapturedPacks.getBuffer()); + irBuilder->setInsertInto(expandInst); + irBuilder->emitBlock(); + auto eachIndex = irBuilder->emitParam(irBuilder->getIntType()); + IRInst* oldExpandIndex = context->expandIndex; + context->expandIndex = eachIndex; + SLANG_DEFER(context->expandIndex = oldExpandIndex); + irBuilder->emitYield(getSimpleVal(context, lowerSubExpr(expr->baseExpr))); + irBuilder->setInsertAfter(expandInst); + return LoweredValInfo::simple(expandInst); } LoweredValInfo getSimpleDefaultVal(IRType* type) @@ -8968,11 +9103,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // in the order they were declared. for (auto member : genericDecl->members) { - if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + if (auto typeParamDecl = as<GenericTypeParamDeclBase>(member)) { - // TODO: use a `TypeKind` to represent the - // classifier of the parameter. - auto param = subBuilder->emitParam(subBuilder->getTypeType()); + IRType* typeKind = nullptr; + if (as<GenericTypePackParamDecl>(member)) + typeKind = subBuilder->getTypeParameterPackKind(); + else + typeKind = subBuilder->getTypeType(); + auto param = subBuilder->emitParam(typeKind); addNameHint(context, param, typeParamDecl); subContext->setValue(typeParamDecl, LoweredValInfo::simple(param)); } @@ -10289,7 +10427,15 @@ LoweredValInfo ensureDecl( } IRBuilder subIRBuilder(context->irBuilder->getModule()); - subIRBuilder.setInsertInto(subIRBuilder.getModule()); + if (as<VarDecl>(decl) && decl->findModifier<LocalTempVarModifier>()) + { + // Do not modify insert location. + subIRBuilder.setInsertLoc(context->irBuilder->getInsertLoc()); + } + else + { + subIRBuilder.setInsertInto(subIRBuilder.getModule()); + } IRGenEnv subEnv; subEnv.outer = context->env; |
