diff options
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 208 |
1 files changed, 180 insertions, 28 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 388ca884e..b1bc63fa1 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1657,6 +1657,127 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return lowerSubExpr(expr->base); } + LoweredValInfo getSimpleDefaultVal(IRType* type) + { + if(auto basicType = as<IRBasicType>(type)) + { + switch( basicType->getBaseType() ) + { + default: + SLANG_UNEXPECTED("missing case for getting IR default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + break; + + case BaseType::Bool: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::UInt8: + case BaseType::UInt16: + case BaseType::UInt: + case BaseType::UInt64: + return LoweredValInfo::simple(getBuilder()->getIntValue(type, 0)); + + case BaseType::Half: + case BaseType::Float: + case BaseType::Double: + return LoweredValInfo::simple(getBuilder()->getFloatValue(type, 0.0)); + } + } + + SLANG_UNEXPECTED("missing case for getting IR default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo getDefaultVal(Type* type) + { + auto irType = lowerType(context, type); + if (auto basicType = type->As<BasicExpressionType>()) + { + return getSimpleDefaultVal(irType); + } + else if (auto vectorType = type->As<VectorExpressionType>()) + { + UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); + + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + + List<IRInst*> args; + for(UInt ee = 0; ee < elementCount; ++ee) + { + args.Add(irDefaultValue); + } + return LoweredValInfo::simple( + getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); + } + else if (auto matrixType = type->As<MatrixExpressionType>()) + { + UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); + + auto rowType = matrixType->getRowType(); + + auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); + + List<IRInst*> args; + for(UInt rr = 0; rr < rowCount; ++rr) + { + args.Add(irDefaultValue); + } + return LoweredValInfo::simple( + getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer())); + } + else if (auto arrayType = type->As<ArrayExpressionType>()) + { + UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); + + auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType)); + + List<IRInst*> args; + for(UInt ee = 0; ee < elementCount; ++ee) + { + args.Add(irDefaultElement); + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeArray(irType, args.Count(), args.Buffer())); + } + else if (auto declRefType = type->As<DeclRefType>()) + { + DeclRef<Decl> declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + List<IRInst*> args; + for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) + { + if (ff.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + auto irFieldVal = getSimpleVal(context, getDefaultVal(ff)); + args.Add(irFieldVal); + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeStruct(irType, args.Count(), args.Buffer())); + } + } + + SLANG_UNEXPECTED("unexpected type when creating default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo getDefaultVal(StructField* decl) + { + if(auto initExpr = decl->initExpr) + { + return lowerRValueExpr(context, initExpr); + } + else + { + return getDefaultVal(decl->type); + } + } + LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) { // Allocate a temporary of the given type @@ -1666,23 +1787,33 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UInt argCount = expr->args.Count(); + // If the initializer list was empty, then the user was + // asking for default initialization, which should apply + // to (almost) any type. + // + if(argCount == 0) + { + return getDefaultVal(type.type); + } + // Now for each argument in the initializer list, // fill in the appropriate field of the result if (auto arrayType = type->As<ArrayExpressionType>()) { UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); - for (UInt ee = 0; ee < elementCount; ++ee) + for (UInt ee = 0; ee < argCount; ++ee) { - if (ee < argCount) - { - auto argExpr = expr->args[ee]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.Add(getSimpleVal(context, argVal)); - } - else + auto argExpr = expr->args[ee]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.Add(getSimpleVal(context, argVal)); + } + if(elementCount > argCount) + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType)); + for(UInt ee = argCount; ee < elementCount; ++ee) { - SLANG_UNEXPECTED("need to default-initialize array elements"); + args.Add(irDefaultValue); } } @@ -1692,25 +1823,48 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> else if (auto vectorType = type->As<VectorExpressionType>()) { UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); - UInt argCounter = 0; - for (UInt ee = 0; ee < elementCount; ++ee) + for (UInt ee = 0; ee < argCount; ++ee) { - UInt argIndex = argCounter++; - if (argIndex < argCount) + auto argExpr = expr->args[ee]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.Add(getSimpleVal(context, argVal)); + } + if(elementCount > argCount) + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + for(UInt ee = argCount; ee < elementCount; ++ee) { - auto argExpr = expr->args[argIndex]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.Add(getSimpleVal(context, argVal)); + args.Add(irDefaultValue); } - else + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); + } + else if (auto matrixType = type->As<MatrixExpressionType>()) + { + UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); + + for (UInt rr = 0; rr < argCount; ++rr) + { + auto argExpr = expr->args[rr]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.Add(getSimpleVal(context, argVal)); + } + if(rowCount > argCount) + { + auto rowType = matrixType->getRowType(); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); + + for(UInt rr = argCount; rr < rowCount; ++rr) { - SLANG_UNEXPECTED("need to default-initialize vector elements"); + args.Add(irDefaultValue); } } return LoweredValInfo::simple( - getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); + getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer())); } else if (auto declRefType = type->As<DeclRefType>()) { @@ -1732,23 +1886,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else { - SLANG_UNEXPECTED("need to default-initialize struct fields"); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(ff)); + args.Add(irDefaultValue); } } return LoweredValInfo::simple( getBuilder()->emitMakeStruct(irType, args.Count(), args.Buffer())); } - else - { - SLANG_UNEXPECTED("not sure how to initialize this type"); - } - } - else - { - SLANG_UNEXPECTED("not sure how to initialize this type"); } + // If none of the above cases matched, then we had better + // have zero arguments in the initailizer list, in which + // case we are just looking for default initialization. + // + SLANG_UNEXPECTED("unhandled case for initializer list codegen"); UNREACHABLE_RETURN(LoweredValInfo()); } |
