summaryrefslogtreecommitdiffstats
path: root/source/slang/lower-to-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
-rw-r--r--source/slang/lower-to-ir.cpp552
1 files changed, 484 insertions, 68 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index bee3edb16..e00dffa1d 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -8,51 +8,74 @@
namespace Slang
{
-struct SharedIRGenContext
-{
- EntryPointRequest* entryPoint;
- ProgramLayout* programLayout;
- CodeGenTarget target;
-};
-
-struct LoweredExprInfo
+struct LoweredValInfo
{
enum class Flavor
{
- Value,
+ None,
+ Simple,
};
- static LoweredExprInfo createValue(IRValue* value)
+ union
{
- LoweredExprInfo result;
- result.flavor = Flavor::Value;
- result.value = value;
- return result;
+ IRValue* val;
+ };
+ Flavor flavor;
+
+ LoweredValInfo()
+ {
+ flavor = Flavor::None;
+ val = nullptr;
}
- Flavor flavor;
- union
+ static LoweredValInfo simple(IRValue* v)
{
- IRValue* value;
- };
+ LoweredValInfo info;
+ info.flavor = Flavor::Simple;
+ info.val = v;
+ return info;
+ }
};
+struct SharedIRGenContext
+{
+ EntryPointRequest* entryPoint;
+ ProgramLayout* programLayout;
+ CodeGenTarget target;
+
+ Dictionary<DeclRef<Decl>, LoweredValInfo> declValues;
+};
+
+
struct IRGenContext
{
- Dictionary<Decl*, LoweredExprInfo> declValues;
+ SharedIRGenContext* shared;
IRBuilder* irBuilder;
};
-struct LoweredValInfo
+IRValue* getSimpleVal(LoweredValInfo lowered)
{
-};
+ switch(lowered.flavor)
+ {
+ case LoweredValInfo::Flavor::None:
+ return nullptr;
+
+ case LoweredValInfo::Flavor::Simple:
+ return lowered.val;
+
+ default:
+ SLANG_UNEXPECTED("unhandled value flavor");
+ return nullptr;
+ }
+}
struct LoweredTypeInfo
{
enum class Flavor
{
- Type,
+ None,
+ Simple,
};
union
@@ -60,8 +83,47 @@ struct LoweredTypeInfo
IRType* type;
};
Flavor flavor;
+
+ LoweredTypeInfo()
+ {
+ flavor = Flavor::None;
+ }
+
+ LoweredTypeInfo(IRType* t)
+ {
+ flavor = Flavor::Simple;
+ type = t;
+ }
};
+IRType* getSimpleType(LoweredTypeInfo lowered)
+{
+ switch(lowered.flavor)
+ {
+ case LoweredTypeInfo::Flavor::None:
+ return nullptr;
+
+ case LoweredTypeInfo::Flavor::Simple:
+ return lowered.type;
+
+ default:
+ SLANG_UNEXPECTED("unhandled value flavor");
+ return nullptr;
+ }
+}
+
+LoweredValInfo lowerVal(
+ IRGenContext* context,
+ Val* val);
+
+IRValue* lowerSimpleVal(
+ IRGenContext* context,
+ Val* val)
+{
+ auto lowered = lowerVal(context, val);
+ return getSimpleVal(lowered);
+}
+
LoweredTypeInfo lowerType(
IRGenContext* context,
Type* type);
@@ -73,28 +135,105 @@ static LoweredTypeInfo lowerType(
return lowerType(context, type.type);
}
-LoweredExprInfo lowerExpr(
+// Lower a type and expect the result to be simple
+IRType* lowerSimpleType(
+ IRGenContext* context,
+ Type* type)
+{
+ auto lowered = lowerType(context, type);
+ return getSimpleType(lowered);
+}
+
+IRType* lowerSimpleType(
+ IRGenContext* context,
+ QualType const& type)
+{
+ auto lowered = lowerType(context, type);
+ return getSimpleType(lowered);
+}
+
+
+LoweredValInfo lowerExpr(
IRGenContext* context,
Expr* expr);
+void lowerStmt(
+ IRGenContext* context,
+ Stmt* stmt);
+
+LoweredValInfo ensureDecl(
+ IRGenContext* context,
+ DeclRef<Decl> const& declRef);
+
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo>
{
IRGenContext* context;
+ IRBuilder* getBuilder() { return context->irBuilder; }
+
LoweredValInfo visitVal(Val* val)
{
SLANG_UNIMPLEMENTED_X("value lowering");
}
+ LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
+ {
+ // TODO: it is a bit messy here that the `ConstantIntVal` representation
+ // has no notion of a *type* associated with the value...
+
+ auto type = getBuilder()->getBaseType(BaseType::Int);
+ return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
+ }
+
LoweredTypeInfo visitType(Type* type)
{
SLANG_UNIMPLEMENTED_X("type lowering");
}
+ LoweredTypeInfo visitDeclRefType(DeclRefType* type)
+ {
+ // Catch-all for user-defined type references
+ LoweredValInfo loweredDeclRef = ensureDecl(context, type->declRef);
+
+ // TODO: make sure that the value is actually a type...
+
+ switch (loweredDeclRef.flavor)
+ {
+ case LoweredValInfo::Flavor::Simple:
+ return LoweredTypeInfo((IRType*)loweredDeclRef.val);
+
+ default:
+ SLANG_UNIMPLEMENTED_X("type lowering");
+ }
+
+ }
+
+ LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type)
+ {
+ return getBuilder()->getBaseType(type->BaseType);
+ }
+
+ LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type)
+ {
+ auto irElementType = lowerSimpleType(context, type->elementType);
+ auto irElementCount = lowerSimpleVal(context, type->elementCount);
+
+ return getBuilder()->getVectorType(irElementType, irElementCount);
+ }
+
};
+LoweredValInfo lowerVal(
+ IRGenContext* context,
+ Val* val)
+{
+ ValLoweringVisitor visitor;
+ visitor.context = context;
+ return visitor.dispatch(val);
+}
+
LoweredTypeInfo lowerType(
IRGenContext* context,
Type* type)
@@ -115,52 +254,40 @@ struct LoweringVisitor
//
-struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredExprInfo>
+struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
{
IRGenContext* context;
- LoweredExprInfo visitVarExpr(VarExpr* expr)
- {
- LoweredExprInfo info;
- if(context->declValues.TryGetValue(expr->declRef.getDecl(), info))
- return info;
-
- throw 99;
+ IRBuilder* getBuilder() { return context->irBuilder; }
- return LoweredExprInfo();
+ LoweredValInfo visitVarExpr(VarExpr* expr)
+ {
+ LoweredValInfo info = ensureDecl(context, expr->declRef);
+ return info;
}
- LoweredExprInfo visitOverloadedExpr(OverloadedExpr* expr)
+ LoweredValInfo visitOverloadedExpr(OverloadedExpr* expr)
{
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
}
- LoweredExprInfo visitInitializerListExpr(InitializerListExpr* expr)
+ LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for initializer list expression");
}
- IRType* getIRType(LoweredTypeInfo const& typeInfo)
+ LoweredValInfo visitConstantExpr(ConstantExpr* expr)
{
- switch( typeInfo.flavor )
- {
- case LoweredTypeInfo::Flavor::Type:
- return typeInfo.type;
- }
- }
-
- LoweredExprInfo visitConstantExpr(ConstantExpr* expr)
- {
- auto type = getIRType(lowerType(context, expr->type));
+ auto type = lowerSimpleType(context, expr->type);
switch( expr->ConstType )
{
case ConstantExpr::ConstantType::Bool:
- return LoweredExprInfo::createValue(context->irBuilder->getBoolValue(expr->integerValue != 0));
+ return LoweredValInfo::simple(context->irBuilder->getBoolValue(expr->integerValue != 0));
case ConstantExpr::ConstantType::Int:
- return LoweredExprInfo::createValue(context->irBuilder->getIntValue(type, expr->integerValue));
+ return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->integerValue));
case ConstantExpr::ConstantType::Float:
- return LoweredExprInfo::createValue(context->irBuilder->getFloatValue(type, expr->floatingPointValue));
+ return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->floatingPointValue));
case ConstantExpr::ConstantType::String:
break;
}
@@ -168,68 +295,198 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredExprInfo>
SLANG_UNEXPECTED("unexpected constant type");
}
- LoweredExprInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr)
+ LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression");
}
- LoweredExprInfo visitInvokeExpr(InvokeExpr* expr)
+ void addArgs(List<IRValue*>* ioArgs, LoweredValInfo argInfo)
+ {
+ auto& args = *ioArgs;
+ switch( argInfo.flavor )
+ {
+ case LoweredValInfo::Flavor::Simple:
+ args.Add(getSimpleVal(argInfo));
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("addArgs case");
+ break;
+ }
+ }
+
+ LoweredValInfo lowerIntrinsicCall(
+ InvokeExpr* expr,
+ IntrinsicOp intrinsicOp)
+ {
+ auto type = lowerSimpleType(context, expr->type);
+
+ List<IRValue*> irArgs;
+ for( auto arg : expr->Arguments )
+ {
+ auto loweredArg = lowerExpr(context, arg);
+ addArgs(&irArgs, loweredArg);
+ }
+
+ UInt argCount = irArgs.Count();
+
+ return LoweredValInfo::simple(getBuilder()->emitIntrinsicInst(type, intrinsicOp, argCount, &irArgs[0]));
+ }
+
+ LoweredValInfo lowerSimpleCall(InvokeExpr* expr)
{
+ auto loweredFunc = lowerExpr(context, expr->FunctionExpr);
+
+ for( auto arg : expr->Arguments )
+ {
+ auto loweredArg = lowerExpr(context, arg);
+ }
+
SLANG_UNIMPLEMENTED_X("codegen for invoke expression");
}
- LoweredExprInfo visitIndexExpr(IndexExpr* expr)
+ LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
+ {
+ // TODO: need to detect calls to builtins here, so that we can expand
+ // them as their own special opcodes...
+
+ auto funcExpr = expr->FunctionExpr;
+ if( auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>() )
+ {
+ auto funcDeclRef = funcDeclRefExpr->declRef;
+ auto funcDecl = funcDeclRef.getDecl();
+ if(auto intrinsicOpModifier = funcDecl->FindModifier<IntrinsicOpModifier>())
+ {
+ return lowerIntrinsicCall(expr, intrinsicOpModifier->op);
+ //
+ }
+ // TODO: handle target intrinsic modifier too...
+
+ if( auto ctorDeclRef = funcDeclRef.As<ConstructorDecl>() )
+ {
+ // HACK: we know all constructors are builtins for now,
+ // so we need to emit them as a call to the corresponding
+ // builtin operation.
+
+ auto type = lowerSimpleType(context, expr->type);
+
+ List<IRValue*> irArgs;
+ for( auto arg : expr->Arguments )
+ {
+ auto loweredArg = lowerExpr(context, arg);
+ addArgs(&irArgs, loweredArg);
+ }
+
+ UInt argCount = irArgs.Count();
+
+ return LoweredValInfo::simple(getBuilder()->emitConstructorInst(type, argCount, &irArgs[0]));
+ }
+ }
+
+ return lowerSimpleCall(expr);
+ }
+
+ LoweredValInfo visitIndexExpr(IndexExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for subscript expression");
}
- LoweredExprInfo visitMemberExpr(MemberExpr* expr)
+ LoweredValInfo extractField(
+ LoweredTypeInfo fieldType,
+ LoweredValInfo base,
+ UInt fieldIndex)
+ {
+ switch (base.flavor)
+ {
+ case LoweredValInfo::Flavor::Simple:
+ {
+ IRValue* irBase = base.val;
+ return LoweredValInfo::simple(
+ getBuilder()->createFieldExtract(
+ getSimpleType(fieldType),
+ irBase,
+ fieldIndex));
+ }
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("codegen for field extract");
+ }
+ }
+
+ LoweredValInfo visitMemberExpr(MemberExpr* expr)
{
+ auto loweredType = lowerType(context, expr->type);
+ auto loweredBase = lowerExpr(context, expr->BaseExpression);
+
+ auto declRef = expr->declRef;
+ if (auto fieldDeclRef = declRef.As<StructField>())
+ {
+ // Okay, easy enough: we have a reference to a field of a struct type...
+
+ // HACK: for now just scan the decl to find the right index.
+ // TODO: we need to deal with the fact that the struct might get
+ // tuple-ified.
+ //
+ UInt index = 0;
+ for (auto fieldDecl : getMembersOfType<StructField>(fieldDeclRef.GetParent().As<AggTypeDecl>()))
+ {
+ if (fieldDecl == fieldDeclRef.getDecl())
+ {
+ break;
+ }
+
+ index++;
+ }
+
+ return extractField(loweredType, loweredBase, index);
+ }
+
SLANG_UNIMPLEMENTED_X("codegen for subscript expression");
}
- LoweredExprInfo visitSwizzleExpr(SwizzleExpr* expr)
+ LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for swizzle expression");
}
- LoweredExprInfo visitDerefExpr(DerefExpr* expr)
+ LoweredValInfo visitDerefExpr(DerefExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for deref expression");
}
- LoweredExprInfo visitTypeCastExpr(TypeCastExpr* expr)
+ LoweredValInfo visitTypeCastExpr(TypeCastExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for type cast expression");
}
- LoweredExprInfo visitSelectExpr(SelectExpr* expr)
+ LoweredValInfo visitSelectExpr(SelectExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for select expression");
}
- LoweredExprInfo visitGenericAppExpr(GenericAppExpr* expr)
+ LoweredValInfo visitGenericAppExpr(GenericAppExpr* expr)
{
SLANG_UNIMPLEMENTED_X("generic application expression during code generation");
}
- LoweredExprInfo visitSharedTypeExpr(SharedTypeExpr* expr)
+ LoweredValInfo visitSharedTypeExpr(SharedTypeExpr* expr)
{
SLANG_UNIMPLEMENTED_X("shared type expression during code generation");
}
- LoweredExprInfo visitAssignExpr(AssignExpr* expr)
+ LoweredValInfo visitAssignExpr(AssignExpr* expr)
{
SLANG_UNIMPLEMENTED_X("shared type expression during code generation");
}
- LoweredExprInfo visitParenExpr(ParenExpr* expr)
+ LoweredValInfo visitParenExpr(ParenExpr* expr)
{
return lowerExpr(context, expr->base);
}
};
-LoweredExprInfo lowerExpr(
+LoweredValInfo lowerExpr(
IRGenContext* context,
Expr* expr)
{
@@ -238,25 +495,141 @@ LoweredExprInfo lowerExpr(
return visitor.dispatch(expr);
}
-struct LoweredDeclInfo
-{};
+struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
+{
+ IRGenContext* context;
+
+ IRBuilder* getBuilder() { return context->irBuilder; }
-struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredDeclInfo>
+ void visitStmt(Stmt* stmt)
+ {
+ SLANG_UNIMPLEMENTED_X("stmt catch-all");
+ }
+
+ void visitBlockStmt(BlockStmt* stmt)
+ {
+ lowerStmt(context, stmt->body);
+ }
+
+ void visitReturnStmt(ReturnStmt* stmt)
+ {
+ if( auto expr = stmt->Expression )
+ {
+ auto loweredExpr = lowerExpr(context, expr);
+
+ getBuilder()->createReturn(getSimpleVal(loweredExpr));
+ }
+ else
+ {
+ getBuilder()->createReturn();
+ }
+ }
+};
+
+void lowerStmt(
+ IRGenContext* context,
+ Stmt* stmt)
+{
+ StmtLoweringVisitor visitor;
+ visitor.context = context;
+ return visitor.dispatch(stmt);
+}
+
+struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
IRGenContext* context;
- LoweredDeclInfo visitDeclBase(DeclBase* decl)
+ IRBuilder* getBuilder()
+ {
+ return context->irBuilder;
+ }
+
+ LoweredValInfo visitDeclBase(DeclBase* decl)
{
SLANG_UNIMPLEMENTED_X("decl catch-all");
}
- LoweredDeclInfo visitDecl(Decl* decl)
+ LoweredValInfo visitDecl(Decl* decl)
{
SLANG_UNIMPLEMENTED_X("decl catch-all");
}
+
+ LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl)
+ {
+ // User-defined aggregate type: need to translate into
+ // a corresponding IR aggregate type.
+
+ List<LoweredTypeInfo> fieldTypes;
+ List<IRType*> irFieldTypes;
+
+ for (auto fieldDecl : decl->GetFields())
+ {
+ // TODO: need to be prepared to deal with tuple-ness of fields here
+ auto fieldType = lowerType(context, fieldDecl->getType());
+
+ fieldTypes.Add(fieldType);
+
+ switch (fieldType.flavor)
+ {
+ case LoweredTypeInfo::Flavor::Simple:
+ irFieldTypes.Add(fieldType.type);
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("struct field type");
+ }
+ }
+
+ // TODO: need to track relationship to original fields...
+
+ IRType* irStructType = getBuilder()->getStructType(
+ irFieldTypes.Count(),
+ &irFieldTypes[0]);
+
+ return LoweredValInfo::simple(irStructType);
+ }
+
+ LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ IRBuilder subBuilderStorage = *getBuilder();
+ IRBuilder* subBuilder = &subBuilderStorage;
+
+ // need to create an IR function here
+
+ IRFunc* irFunc = subBuilder->createFunc();
+ subBuilder->parentInst = irFunc;
+
+ IRBlock* entryBlock = subBuilder->createBlock();
+ subBuilder->parentInst = entryBlock;
+
+ IRGenContext subContextStorage = *context;
+ IRGenContext* subContext = &subContextStorage;
+ subContext->irBuilder = subBuilder;
+
+ // set up sub context for generating our new function
+
+ for( auto paramDecl : decl->GetParameters() )
+ {
+ IRType* irParamType = lowerSimpleType(context, paramDecl->getType());
+ IRParam* irParam = subBuilder->createParam(irParamType);
+
+ DeclRef<ParamDecl> paramDeclRef = makeDeclRef(paramDecl.Ptr());
+
+ LoweredValInfo irParamVal = LoweredValInfo::simple(irParam);
+
+ subContext->shared->declValues.Add(paramDeclRef, irParamVal);
+ }
+
+ auto irResultType = lowerType(context, decl->ReturnType);
+
+
+ lowerStmt(subContext, decl->Body);
+
+ return LoweredValInfo::simple(irFunc);
+ }
};
-LoweredDeclInfo lowerDecl(
+LoweredValInfo lowerDecl(
IRGenContext* context,
Decl* decl)
{
@@ -265,6 +638,30 @@ LoweredDeclInfo lowerDecl(
return visitor.dispatch(decl);
}
+LoweredValInfo ensureDecl(
+ IRGenContext* context,
+ DeclRef<Decl> const& declRef)
+{
+ auto shared = context->shared;
+
+ LoweredValInfo result;
+ if(shared->declValues.TryGetValue(declRef, result))
+ return result;
+
+ // TODO: this is where we need to apply any specializations
+ // from the declaration reference, so that they can be
+ // applied correctly to the declaration itself...
+
+ IRGenContext subContext = *context;
+
+ result = lowerDecl(context, declRef.getDecl());
+
+ shared->declValues[declRef] = result;
+
+ return result;
+}
+
+
EntryPointLayout* findEntryPointLayout(
SharedIRGenContext* shared,
EntryPointRequest* entryPointRequest)
@@ -299,7 +696,7 @@ static void lowerEntryPointToIR(
lowerDecl(context, entryPointFunc);
}
-void lowerEntryPointToIR(
+IRModule* lowerEntryPointToIR(
EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
CodeGenTarget target)
@@ -307,13 +704,32 @@ void lowerEntryPointToIR(
SharedIRGenContext sharedContextStorage;
SharedIRGenContext* sharedContext = &sharedContextStorage;
+ sharedContext->entryPoint = entryPoint;
+ sharedContext->programLayout = programLayout;
+ sharedContext->target = target;
+
IRGenContext contextStorage;
IRGenContext* context = &contextStorage;
+ context->shared = sharedContext;
+
+ IRBuilder builderStorage;
+ IRBuilder* builder = &builderStorage;
+ builder->module = nullptr;
+ builder->parentInst = nullptr;
+
+ IRModule* module = builder->createModule();
+ builder->module = module;
+ builder->parentInst = module;
+
+ context->irBuilder = builder;
+
auto entryPointLayout = findEntryPointLayout(sharedContext, entryPoint);
lowerEntryPointToIR(context, entryPoint, entryPointLayout);
+ return module;
+
}
}