summaryrefslogtreecommitdiff
path: root/source/slang/lower-to-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
-rw-r--r--source/slang/lower-to-ir.cpp208
1 files changed, 180 insertions, 28 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 388ca884e..b1bc63fa1 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1657,6 +1657,127 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return lowerSubExpr(expr->base);
}
+ LoweredValInfo getSimpleDefaultVal(IRType* type)
+ {
+ if(auto basicType = as<IRBasicType>(type))
+ {
+ switch( basicType->getBaseType() )
+ {
+ default:
+ SLANG_UNEXPECTED("missing case for getting IR default value");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ break;
+
+ case BaseType::Bool:
+ case BaseType::Int8:
+ case BaseType::Int16:
+ case BaseType::Int:
+ case BaseType::Int64:
+ case BaseType::UInt8:
+ case BaseType::UInt16:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ return LoweredValInfo::simple(getBuilder()->getIntValue(type, 0));
+
+ case BaseType::Half:
+ case BaseType::Float:
+ case BaseType::Double:
+ return LoweredValInfo::simple(getBuilder()->getFloatValue(type, 0.0));
+ }
+ }
+
+ SLANG_UNEXPECTED("missing case for getting IR default value");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ LoweredValInfo getDefaultVal(Type* type)
+ {
+ auto irType = lowerType(context, type);
+ if (auto basicType = type->As<BasicExpressionType>())
+ {
+ return getSimpleDefaultVal(irType);
+ }
+ else if (auto vectorType = type->As<VectorExpressionType>())
+ {
+ UInt elementCount = (UInt) GetIntVal(vectorType->elementCount);
+
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType));
+
+ List<IRInst*> args;
+ for(UInt ee = 0; ee < elementCount; ++ee)
+ {
+ args.Add(irDefaultValue);
+ }
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer()));
+ }
+ else if (auto matrixType = type->As<MatrixExpressionType>())
+ {
+ UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount());
+
+ auto rowType = matrixType->getRowType();
+
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType));
+
+ List<IRInst*> args;
+ for(UInt rr = 0; rr < rowCount; ++rr)
+ {
+ args.Add(irDefaultValue);
+ }
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer()));
+ }
+ else if (auto arrayType = type->As<ArrayExpressionType>())
+ {
+ UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength);
+
+ auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType));
+
+ List<IRInst*> args;
+ for(UInt ee = 0; ee < elementCount; ++ee)
+ {
+ args.Add(irDefaultElement);
+ }
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeArray(irType, args.Count(), args.Buffer()));
+ }
+ else if (auto declRefType = type->As<DeclRefType>())
+ {
+ DeclRef<Decl> declRef = declRefType->declRef;
+ if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>())
+ {
+ List<IRInst*> args;
+ for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef))
+ {
+ if (ff.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ auto irFieldVal = getSimpleVal(context, getDefaultVal(ff));
+ args.Add(irFieldVal);
+ }
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeStruct(irType, args.Count(), args.Buffer()));
+ }
+ }
+
+ SLANG_UNEXPECTED("unexpected type when creating default value");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+ LoweredValInfo getDefaultVal(StructField* decl)
+ {
+ if(auto initExpr = decl->initExpr)
+ {
+ return lowerRValueExpr(context, initExpr);
+ }
+ else
+ {
+ return getDefaultVal(decl->type);
+ }
+ }
+
LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr)
{
// Allocate a temporary of the given type
@@ -1666,23 +1787,33 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UInt argCount = expr->args.Count();
+ // If the initializer list was empty, then the user was
+ // asking for default initialization, which should apply
+ // to (almost) any type.
+ //
+ if(argCount == 0)
+ {
+ return getDefaultVal(type.type);
+ }
+
// Now for each argument in the initializer list,
// fill in the appropriate field of the result
if (auto arrayType = type->As<ArrayExpressionType>())
{
UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength);
- for (UInt ee = 0; ee < elementCount; ++ee)
+ for (UInt ee = 0; ee < argCount; ++ee)
{
- if (ee < argCount)
- {
- auto argExpr = expr->args[ee];
- LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
- args.Add(getSimpleVal(context, argVal));
- }
- else
+ auto argExpr = expr->args[ee];
+ LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
+ args.Add(getSimpleVal(context, argVal));
+ }
+ if(elementCount > argCount)
+ {
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType));
+ for(UInt ee = argCount; ee < elementCount; ++ee)
{
- SLANG_UNEXPECTED("need to default-initialize array elements");
+ args.Add(irDefaultValue);
}
}
@@ -1692,25 +1823,48 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
else if (auto vectorType = type->As<VectorExpressionType>())
{
UInt elementCount = (UInt) GetIntVal(vectorType->elementCount);
- UInt argCounter = 0;
- for (UInt ee = 0; ee < elementCount; ++ee)
+ for (UInt ee = 0; ee < argCount; ++ee)
{
- UInt argIndex = argCounter++;
- if (argIndex < argCount)
+ auto argExpr = expr->args[ee];
+ LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
+ args.Add(getSimpleVal(context, argVal));
+ }
+ if(elementCount > argCount)
+ {
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType));
+ for(UInt ee = argCount; ee < elementCount; ++ee)
{
- auto argExpr = expr->args[argIndex];
- LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
- args.Add(getSimpleVal(context, argVal));
+ args.Add(irDefaultValue);
}
- else
+ }
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer()));
+ }
+ else if (auto matrixType = type->As<MatrixExpressionType>())
+ {
+ UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount());
+
+ for (UInt rr = 0; rr < argCount; ++rr)
+ {
+ auto argExpr = expr->args[rr];
+ LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
+ args.Add(getSimpleVal(context, argVal));
+ }
+ if(rowCount > argCount)
+ {
+ auto rowType = matrixType->getRowType();
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType));
+
+ for(UInt rr = argCount; rr < rowCount; ++rr)
{
- SLANG_UNEXPECTED("need to default-initialize vector elements");
+ args.Add(irDefaultValue);
}
}
return LoweredValInfo::simple(
- getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer()));
+ getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer()));
}
else if (auto declRefType = type->As<DeclRefType>())
{
@@ -1732,23 +1886,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else
{
- SLANG_UNEXPECTED("need to default-initialize struct fields");
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(ff));
+ args.Add(irDefaultValue);
}
}
return LoweredValInfo::simple(
getBuilder()->emitMakeStruct(irType, args.Count(), args.Buffer()));
}
- else
- {
- SLANG_UNEXPECTED("not sure how to initialize this type");
- }
- }
- else
- {
- SLANG_UNEXPECTED("not sure how to initialize this type");
}
+ // If none of the above cases matched, then we had better
+ // have zero arguments in the initailizer list, in which
+ // case we are just looking for default initialization.
+ //
+ SLANG_UNEXPECTED("unhandled case for initializer list codegen");
UNREACHABLE_RETURN(LoweredValInfo());
}