summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp168
1 files changed, 157 insertions, 11 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 95e9d96da..9ceb3074a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -594,6 +594,9 @@ struct IRGenContext
bool includeDebugInfo = false;
+ // The element index if we are inside an `expand` expression.
+ IRInst* expandIndex = nullptr;
+
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
: shared(inShared)
, astBuilder(inAstBuilder)
@@ -1653,6 +1656,86 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(resultVal);
}
+ LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack)
+ {
+ ShortList<IRType*> types;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto loweredType = lowerType(context, typePack->getElementType(i));
+ types.add(loweredType);
+ }
+ auto irBuilder = getBuilder();
+ IRType* irTypePack = irBuilder->getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer());
+ return LoweredValInfo::simple(irTypePack);
+ }
+
+ LoweredValInfo visitEachType(EachType* eachType)
+ {
+ auto type = lowerType(context, eachType->getElementType());
+ return LoweredValInfo::simple(getBuilder()->emitEachInst(
+ getBuilder()->getTypeKind(),
+ type));
+ }
+
+ LoweredValInfo visitExpandType(ExpandType* expandType)
+ {
+ auto irBuilder = getBuilder();
+ auto type = lowerType(context, expandType->getPatternType());
+ ShortList<IRInst*> capturedTypes;
+ for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++)
+ {
+ auto loweredType = lowerType(context, expandType->getCapturedTypePack(i));
+ capturedTypes.add(loweredType);
+ }
+ return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(
+ irBuilder->getTypeKind(), type, capturedTypes.getArrayView().arrayView));
+ }
+
+ LoweredValInfo visitTypePackSubtypeWitness(TypePackSubtypeWitness* witnessPack)
+ {
+ auto irBuilder = getBuilder();
+ ShortList<IRInst*> witnesses;
+ ShortList<IRType*> elementTypes;
+ for (Index i = 0; i < witnessPack->getCount(); i++)
+ {
+ auto loweredWitness = lowerVal(context, witnessPack->getWitness(i));
+ witnesses.add(loweredWitness.val);
+ elementTypes.add(loweredWitness.val->getFullType());
+ }
+ auto irWitnessPack = irBuilder->emitMakeWitnessPack(
+ irBuilder->getTupleType((UInt)elementTypes.getCount(), elementTypes.getArrayView().getBuffer()),
+ witnesses.getArrayView().arrayView);
+ return LoweredValInfo::simple(irWitnessPack);
+ }
+
+ LoweredValInfo visitExpandSubtypeWitness(ExpandSubtypeWitness* witness)
+ {
+ auto irBuilder = getBuilder();
+
+ auto patternWitnessVal = lowerVal(context, witness->getPatternTypeWitness());
+ auto subType = lowerType(context, witness->getSub());
+ auto supType = lowerType(context, witness->getSup());
+ auto witnessTableType = irBuilder->getWitnessTableType(supType);
+ ShortList<IRInst*> captures;
+ if (auto expandType = as<IRExpandType>(subType))
+ {
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ captures.add(expandType->getCaptureType(i));
+ }
+ }
+ return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(witnessTableType, patternWitnessVal.val, captures.getArrayView().arrayView));
+ }
+
+ LoweredValInfo visitEachSubtypeWitness(EachSubtypeWitness* witness)
+ {
+ auto elementWitness = lowerVal(context, witness->getPatternTypeWitness());
+ auto irBuilder = getBuilder();
+ auto subType = lowerType(context, witness->getSub());
+ auto witnessTableType = irBuilder->getWitnessTableType(subType);
+ return LoweredValInfo::simple(irBuilder->emitEachInst(witnessTableType, getSimpleVal(context, elementWitness)));
+ }
+
LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val)
{
if (as<ThisTypeConstraintDecl>(val->getDeclRef()))
@@ -1885,6 +1968,23 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
context->irBuilder->getTypeKind()));
}
+ IRType* visitTupleType(TupleType* type)
+ {
+ List<IRType*> elementTypes;
+ if (as<ConcreteTypePack>(type->getTypePack()))
+ {
+ for (Index i = 0; i < type->getMemberCount(); i++)
+ {
+ elementTypes.add(lowerType(context, type->getMember(i)));
+ }
+ return context->irBuilder->getTupleType(elementTypes);
+ }
+ else
+ {
+ return lowerType(context, type->getTypePack());
+ }
+ }
+
IRType* visitNamedExpressionType(NamedExpressionType* type)
{
return (IRType*)getSimpleVal(context, dispatchType(type->getCanonicalType()));
@@ -4315,19 +4415,54 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return lowerSubExpr(expr->base);
}
- LoweredValInfo visitPackExpr(PackExpr*)
+ LoweredValInfo visitPackExpr(PackExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for pack expression");
+ List<IRInst*> irArgs;
+ for (auto arg : expr->args)
+ {
+ irArgs.add(getSimpleVal(context, lowerSubExpr(arg)));
+ }
+ auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs);
+ return LoweredValInfo::simple(irMakeTuple);
}
- LoweredValInfo visitEachExpr(EachExpr*)
+ LoweredValInfo visitEachExpr(EachExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for each expression");
+ auto subVal = lowerSubExpr(expr->baseExpr);
+ SLANG_ASSERT(context->expandIndex);
+ auto irEach = getBuilder()->emitGetTupleElement(lowerType(context, expr->type), getSimpleVal(context, subVal), context->expandIndex);
+ return LoweredValInfo::simple(irEach);
}
- LoweredValInfo visitExpandExpr(ExpandExpr*)
+ LoweredValInfo visitExpandExpr(ExpandExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for expand expression");
+ auto irBuilder = getBuilder();
+ auto irType = lowerType(context, expr->type);
+ List<IRInst*> irCapturedPacks;
+ if (auto expandType = as<IRExpandType>(irType))
+ {
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ irCapturedPacks.add(expandType->getCaptureType(i));
+ }
+ }
+ else
+ {
+ // If the type of the expression is not an ExpandType, then it must be
+ // a DeclRefType to a generic type pack parameter.
+ // In this case, the captured type is just the DeclRefType itself.
+ irCapturedPacks.add(irType);
+ }
+ auto expandInst = irBuilder->emitExpandInst(irType, (UInt)irCapturedPacks.getCount(), irCapturedPacks.getBuffer());
+ irBuilder->setInsertInto(expandInst);
+ irBuilder->emitBlock();
+ auto eachIndex = irBuilder->emitParam(irBuilder->getIntType());
+ IRInst* oldExpandIndex = context->expandIndex;
+ context->expandIndex = eachIndex;
+ SLANG_DEFER(context->expandIndex = oldExpandIndex);
+ irBuilder->emitYield(getSimpleVal(context, lowerSubExpr(expr->baseExpr)));
+ irBuilder->setInsertAfter(expandInst);
+ return LoweredValInfo::simple(expandInst);
}
LoweredValInfo getSimpleDefaultVal(IRType* type)
@@ -8968,11 +9103,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// in the order they were declared.
for (auto member : genericDecl->members)
{
- if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ if (auto typeParamDecl = as<GenericTypeParamDeclBase>(member))
{
- // TODO: use a `TypeKind` to represent the
- // classifier of the parameter.
- auto param = subBuilder->emitParam(subBuilder->getTypeType());
+ IRType* typeKind = nullptr;
+ if (as<GenericTypePackParamDecl>(member))
+ typeKind = subBuilder->getTypeParameterPackKind();
+ else
+ typeKind = subBuilder->getTypeType();
+ auto param = subBuilder->emitParam(typeKind);
addNameHint(context, param, typeParamDecl);
subContext->setValue(typeParamDecl, LoweredValInfo::simple(param));
}
@@ -10289,7 +10427,15 @@ LoweredValInfo ensureDecl(
}
IRBuilder subIRBuilder(context->irBuilder->getModule());
- subIRBuilder.setInsertInto(subIRBuilder.getModule());
+ if (as<VarDecl>(decl) && decl->findModifier<LocalTempVarModifier>())
+ {
+ // Do not modify insert location.
+ subIRBuilder.setInsertLoc(context->irBuilder->getInsertLoc());
+ }
+ else
+ {
+ subIRBuilder.setInsertInto(subIRBuilder.getModule());
+ }
IRGenEnv subEnv;
subEnv.outer = context->env;