diff options
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 552 |
1 files changed, 484 insertions, 68 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bee3edb16..e00dffa1d 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -8,51 +8,74 @@ namespace Slang { -struct SharedIRGenContext -{ - EntryPointRequest* entryPoint; - ProgramLayout* programLayout; - CodeGenTarget target; -}; - -struct LoweredExprInfo +struct LoweredValInfo { enum class Flavor { - Value, + None, + Simple, }; - static LoweredExprInfo createValue(IRValue* value) + union { - LoweredExprInfo result; - result.flavor = Flavor::Value; - result.value = value; - return result; + IRValue* val; + }; + Flavor flavor; + + LoweredValInfo() + { + flavor = Flavor::None; + val = nullptr; } - Flavor flavor; - union + static LoweredValInfo simple(IRValue* v) { - IRValue* value; - }; + LoweredValInfo info; + info.flavor = Flavor::Simple; + info.val = v; + return info; + } }; +struct SharedIRGenContext +{ + EntryPointRequest* entryPoint; + ProgramLayout* programLayout; + CodeGenTarget target; + + Dictionary<DeclRef<Decl>, LoweredValInfo> declValues; +}; + + struct IRGenContext { - Dictionary<Decl*, LoweredExprInfo> declValues; + SharedIRGenContext* shared; IRBuilder* irBuilder; }; -struct LoweredValInfo +IRValue* getSimpleVal(LoweredValInfo lowered) { -}; + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + return nullptr; + + case LoweredValInfo::Flavor::Simple: + return lowered.val; + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + return nullptr; + } +} struct LoweredTypeInfo { enum class Flavor { - Type, + None, + Simple, }; union @@ -60,8 +83,47 @@ struct LoweredTypeInfo IRType* type; }; Flavor flavor; + + LoweredTypeInfo() + { + flavor = Flavor::None; + } + + LoweredTypeInfo(IRType* t) + { + flavor = Flavor::Simple; + type = t; + } }; +IRType* getSimpleType(LoweredTypeInfo lowered) +{ + switch(lowered.flavor) + { + case LoweredTypeInfo::Flavor::None: + return nullptr; + + case LoweredTypeInfo::Flavor::Simple: + return lowered.type; + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + return nullptr; + } +} + +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val); + +IRValue* lowerSimpleVal( + IRGenContext* context, + Val* val) +{ + auto lowered = lowerVal(context, val); + return getSimpleVal(lowered); +} + LoweredTypeInfo lowerType( IRGenContext* context, Type* type); @@ -73,28 +135,105 @@ static LoweredTypeInfo lowerType( return lowerType(context, type.type); } -LoweredExprInfo lowerExpr( +// Lower a type and expect the result to be simple +IRType* lowerSimpleType( + IRGenContext* context, + Type* type) +{ + auto lowered = lowerType(context, type); + return getSimpleType(lowered); +} + +IRType* lowerSimpleType( + IRGenContext* context, + QualType const& type) +{ + auto lowered = lowerType(context, type); + return getSimpleType(lowered); +} + + +LoweredValInfo lowerExpr( IRGenContext* context, Expr* expr); +void lowerStmt( + IRGenContext* context, + Stmt* stmt); + +LoweredValInfo ensureDecl( + IRGenContext* context, + DeclRef<Decl> const& declRef); + // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo> { IRGenContext* context; + IRBuilder* getBuilder() { return context->irBuilder; } + LoweredValInfo visitVal(Val* val) { SLANG_UNIMPLEMENTED_X("value lowering"); } + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) + { + // TODO: it is a bit messy here that the `ConstantIntVal` representation + // has no notion of a *type* associated with the value... + + auto type = getBuilder()->getBaseType(BaseType::Int); + return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); + } + LoweredTypeInfo visitType(Type* type) { SLANG_UNIMPLEMENTED_X("type lowering"); } + LoweredTypeInfo visitDeclRefType(DeclRefType* type) + { + // Catch-all for user-defined type references + LoweredValInfo loweredDeclRef = ensureDecl(context, type->declRef); + + // TODO: make sure that the value is actually a type... + + switch (loweredDeclRef.flavor) + { + case LoweredValInfo::Flavor::Simple: + return LoweredTypeInfo((IRType*)loweredDeclRef.val); + + default: + SLANG_UNIMPLEMENTED_X("type lowering"); + } + + } + + LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type) + { + return getBuilder()->getBaseType(type->BaseType); + } + + LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type) + { + auto irElementType = lowerSimpleType(context, type->elementType); + auto irElementCount = lowerSimpleVal(context, type->elementCount); + + return getBuilder()->getVectorType(irElementType, irElementCount); + } + }; +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val) +{ + ValLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(val); +} + LoweredTypeInfo lowerType( IRGenContext* context, Type* type) @@ -115,52 +254,40 @@ struct LoweringVisitor // -struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredExprInfo> +struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> { IRGenContext* context; - LoweredExprInfo visitVarExpr(VarExpr* expr) - { - LoweredExprInfo info; - if(context->declValues.TryGetValue(expr->declRef.getDecl(), info)) - return info; - - throw 99; + IRBuilder* getBuilder() { return context->irBuilder; } - return LoweredExprInfo(); + LoweredValInfo visitVarExpr(VarExpr* expr) + { + LoweredValInfo info = ensureDecl(context, expr->declRef); + return info; } - LoweredExprInfo visitOverloadedExpr(OverloadedExpr* expr) + LoweredValInfo visitOverloadedExpr(OverloadedExpr* expr) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); } - LoweredExprInfo visitInitializerListExpr(InitializerListExpr* expr) + LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for initializer list expression"); } - IRType* getIRType(LoweredTypeInfo const& typeInfo) + LoweredValInfo visitConstantExpr(ConstantExpr* expr) { - switch( typeInfo.flavor ) - { - case LoweredTypeInfo::Flavor::Type: - return typeInfo.type; - } - } - - LoweredExprInfo visitConstantExpr(ConstantExpr* expr) - { - auto type = getIRType(lowerType(context, expr->type)); + auto type = lowerSimpleType(context, expr->type); switch( expr->ConstType ) { case ConstantExpr::ConstantType::Bool: - return LoweredExprInfo::createValue(context->irBuilder->getBoolValue(expr->integerValue != 0)); + return LoweredValInfo::simple(context->irBuilder->getBoolValue(expr->integerValue != 0)); case ConstantExpr::ConstantType::Int: - return LoweredExprInfo::createValue(context->irBuilder->getIntValue(type, expr->integerValue)); + return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->integerValue)); case ConstantExpr::ConstantType::Float: - return LoweredExprInfo::createValue(context->irBuilder->getFloatValue(type, expr->floatingPointValue)); + return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->floatingPointValue)); case ConstantExpr::ConstantType::String: break; } @@ -168,68 +295,198 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredExprInfo> SLANG_UNEXPECTED("unexpected constant type"); } - LoweredExprInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr) + LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); } - LoweredExprInfo visitInvokeExpr(InvokeExpr* expr) + void addArgs(List<IRValue*>* ioArgs, LoweredValInfo argInfo) + { + auto& args = *ioArgs; + switch( argInfo.flavor ) + { + case LoweredValInfo::Flavor::Simple: + args.Add(getSimpleVal(argInfo)); + break; + + default: + SLANG_UNIMPLEMENTED_X("addArgs case"); + break; + } + } + + LoweredValInfo lowerIntrinsicCall( + InvokeExpr* expr, + IntrinsicOp intrinsicOp) + { + auto type = lowerSimpleType(context, expr->type); + + List<IRValue*> irArgs; + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + addArgs(&irArgs, loweredArg); + } + + UInt argCount = irArgs.Count(); + + return LoweredValInfo::simple(getBuilder()->emitIntrinsicInst(type, intrinsicOp, argCount, &irArgs[0])); + } + + LoweredValInfo lowerSimpleCall(InvokeExpr* expr) { + auto loweredFunc = lowerExpr(context, expr->FunctionExpr); + + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + } + SLANG_UNIMPLEMENTED_X("codegen for invoke expression"); } - LoweredExprInfo visitIndexExpr(IndexExpr* expr) + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) + { + // TODO: need to detect calls to builtins here, so that we can expand + // them as their own special opcodes... + + auto funcExpr = expr->FunctionExpr; + if( auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>() ) + { + auto funcDeclRef = funcDeclRefExpr->declRef; + auto funcDecl = funcDeclRef.getDecl(); + if(auto intrinsicOpModifier = funcDecl->FindModifier<IntrinsicOpModifier>()) + { + return lowerIntrinsicCall(expr, intrinsicOpModifier->op); + // + } + // TODO: handle target intrinsic modifier too... + + if( auto ctorDeclRef = funcDeclRef.As<ConstructorDecl>() ) + { + // HACK: we know all constructors are builtins for now, + // so we need to emit them as a call to the corresponding + // builtin operation. + + auto type = lowerSimpleType(context, expr->type); + + List<IRValue*> irArgs; + for( auto arg : expr->Arguments ) + { + auto loweredArg = lowerExpr(context, arg); + addArgs(&irArgs, loweredArg); + } + + UInt argCount = irArgs.Count(); + + return LoweredValInfo::simple(getBuilder()->emitConstructorInst(type, argCount, &irArgs[0])); + } + } + + return lowerSimpleCall(expr); + } + + LoweredValInfo visitIndexExpr(IndexExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); } - LoweredExprInfo visitMemberExpr(MemberExpr* expr) + LoweredValInfo extractField( + LoweredTypeInfo fieldType, + LoweredValInfo base, + UInt fieldIndex) + { + switch (base.flavor) + { + case LoweredValInfo::Flavor::Simple: + { + IRValue* irBase = base.val; + return LoweredValInfo::simple( + getBuilder()->createFieldExtract( + getSimpleType(fieldType), + irBase, + fieldIndex)); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("codegen for field extract"); + } + } + + LoweredValInfo visitMemberExpr(MemberExpr* expr) { + auto loweredType = lowerType(context, expr->type); + auto loweredBase = lowerExpr(context, expr->BaseExpression); + + auto declRef = expr->declRef; + if (auto fieldDeclRef = declRef.As<StructField>()) + { + // Okay, easy enough: we have a reference to a field of a struct type... + + // HACK: for now just scan the decl to find the right index. + // TODO: we need to deal with the fact that the struct might get + // tuple-ified. + // + UInt index = 0; + for (auto fieldDecl : getMembersOfType<StructField>(fieldDeclRef.GetParent().As<AggTypeDecl>())) + { + if (fieldDecl == fieldDeclRef.getDecl()) + { + break; + } + + index++; + } + + return extractField(loweredType, loweredBase, index); + } + SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); } - LoweredExprInfo visitSwizzleExpr(SwizzleExpr* expr) + LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for swizzle expression"); } - LoweredExprInfo visitDerefExpr(DerefExpr* expr) + LoweredValInfo visitDerefExpr(DerefExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for deref expression"); } - LoweredExprInfo visitTypeCastExpr(TypeCastExpr* expr) + LoweredValInfo visitTypeCastExpr(TypeCastExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for type cast expression"); } - LoweredExprInfo visitSelectExpr(SelectExpr* expr) + LoweredValInfo visitSelectExpr(SelectExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for select expression"); } - LoweredExprInfo visitGenericAppExpr(GenericAppExpr* expr) + LoweredValInfo visitGenericAppExpr(GenericAppExpr* expr) { SLANG_UNIMPLEMENTED_X("generic application expression during code generation"); } - LoweredExprInfo visitSharedTypeExpr(SharedTypeExpr* expr) + LoweredValInfo visitSharedTypeExpr(SharedTypeExpr* expr) { SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); } - LoweredExprInfo visitAssignExpr(AssignExpr* expr) + LoweredValInfo visitAssignExpr(AssignExpr* expr) { SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); } - LoweredExprInfo visitParenExpr(ParenExpr* expr) + LoweredValInfo visitParenExpr(ParenExpr* expr) { return lowerExpr(context, expr->base); } }; -LoweredExprInfo lowerExpr( +LoweredValInfo lowerExpr( IRGenContext* context, Expr* expr) { @@ -238,25 +495,141 @@ LoweredExprInfo lowerExpr( return visitor.dispatch(expr); } -struct LoweredDeclInfo -{}; +struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> +{ + IRGenContext* context; + + IRBuilder* getBuilder() { return context->irBuilder; } -struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredDeclInfo> + void visitStmt(Stmt* stmt) + { + SLANG_UNIMPLEMENTED_X("stmt catch-all"); + } + + void visitBlockStmt(BlockStmt* stmt) + { + lowerStmt(context, stmt->body); + } + + void visitReturnStmt(ReturnStmt* stmt) + { + if( auto expr = stmt->Expression ) + { + auto loweredExpr = lowerExpr(context, expr); + + getBuilder()->createReturn(getSimpleVal(loweredExpr)); + } + else + { + getBuilder()->createReturn(); + } + } +}; + +void lowerStmt( + IRGenContext* context, + Stmt* stmt) +{ + StmtLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(stmt); +} + +struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { IRGenContext* context; - LoweredDeclInfo visitDeclBase(DeclBase* decl) + IRBuilder* getBuilder() + { + return context->irBuilder; + } + + LoweredValInfo visitDeclBase(DeclBase* decl) { SLANG_UNIMPLEMENTED_X("decl catch-all"); } - LoweredDeclInfo visitDecl(Decl* decl) + LoweredValInfo visitDecl(Decl* decl) { SLANG_UNIMPLEMENTED_X("decl catch-all"); } + + LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) + { + // User-defined aggregate type: need to translate into + // a corresponding IR aggregate type. + + List<LoweredTypeInfo> fieldTypes; + List<IRType*> irFieldTypes; + + for (auto fieldDecl : decl->GetFields()) + { + // TODO: need to be prepared to deal with tuple-ness of fields here + auto fieldType = lowerType(context, fieldDecl->getType()); + + fieldTypes.Add(fieldType); + + switch (fieldType.flavor) + { + case LoweredTypeInfo::Flavor::Simple: + irFieldTypes.Add(fieldType.type); + break; + + default: + SLANG_UNIMPLEMENTED_X("struct field type"); + } + } + + // TODO: need to track relationship to original fields... + + IRType* irStructType = getBuilder()->getStructType( + irFieldTypes.Count(), + &irFieldTypes[0]); + + return LoweredValInfo::simple(irStructType); + } + + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) + { + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + + // need to create an IR function here + + IRFunc* irFunc = subBuilder->createFunc(); + subBuilder->parentInst = irFunc; + + IRBlock* entryBlock = subBuilder->createBlock(); + subBuilder->parentInst = entryBlock; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->irBuilder = subBuilder; + + // set up sub context for generating our new function + + for( auto paramDecl : decl->GetParameters() ) + { + IRType* irParamType = lowerSimpleType(context, paramDecl->getType()); + IRParam* irParam = subBuilder->createParam(irParamType); + + DeclRef<ParamDecl> paramDeclRef = makeDeclRef(paramDecl.Ptr()); + + LoweredValInfo irParamVal = LoweredValInfo::simple(irParam); + + subContext->shared->declValues.Add(paramDeclRef, irParamVal); + } + + auto irResultType = lowerType(context, decl->ReturnType); + + + lowerStmt(subContext, decl->Body); + + return LoweredValInfo::simple(irFunc); + } }; -LoweredDeclInfo lowerDecl( +LoweredValInfo lowerDecl( IRGenContext* context, Decl* decl) { @@ -265,6 +638,30 @@ LoweredDeclInfo lowerDecl( return visitor.dispatch(decl); } +LoweredValInfo ensureDecl( + IRGenContext* context, + DeclRef<Decl> const& declRef) +{ + auto shared = context->shared; + + LoweredValInfo result; + if(shared->declValues.TryGetValue(declRef, result)) + return result; + + // TODO: this is where we need to apply any specializations + // from the declaration reference, so that they can be + // applied correctly to the declaration itself... + + IRGenContext subContext = *context; + + result = lowerDecl(context, declRef.getDecl()); + + shared->declValues[declRef] = result; + + return result; +} + + EntryPointLayout* findEntryPointLayout( SharedIRGenContext* shared, EntryPointRequest* entryPointRequest) @@ -299,7 +696,7 @@ static void lowerEntryPointToIR( lowerDecl(context, entryPointFunc); } -void lowerEntryPointToIR( +IRModule* lowerEntryPointToIR( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target) @@ -307,13 +704,32 @@ void lowerEntryPointToIR( SharedIRGenContext sharedContextStorage; SharedIRGenContext* sharedContext = &sharedContextStorage; + sharedContext->entryPoint = entryPoint; + sharedContext->programLayout = programLayout; + sharedContext->target = target; + IRGenContext contextStorage; IRGenContext* context = &contextStorage; + context->shared = sharedContext; + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->module = nullptr; + builder->parentInst = nullptr; + + IRModule* module = builder->createModule(); + builder->module = module; + builder->parentInst = module; + + context->irBuilder = builder; + auto entryPointLayout = findEntryPointLayout(sharedContext, entryPoint); lowerEntryPointToIR(context, entryPoint, entryPointLayout); + return module; + } } |
