summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-10-05 12:52:49 -0400
committerGitHub <noreply@github.com>2023-10-05 09:52:49 -0700
commit441e13e13f30b96eb04c05725ad7fe1983c92f53 (patch)
treeaee5c31b62876ef8ad60a37b2a4767b6f1a299c6
parent65751ce222adb302e62b5b7b6312de65638abed5 (diff)
Various AD Fixes (#3263)
* Various fixes * Remove unused parameter * Update slang-ir-loop-unroll.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-check-expr.cpp14
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp12
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp110
-rw-r--r--source/slang/slang-ir-synthesize-active-mask.cpp59
-rw-r--r--tests/autodiff/generic-differential-synthesis.slang35
-rw-r--r--tests/autodiff/generic-differential-synthesis.slang.expected.txt5
9 files changed, 185 insertions, 59 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index abdd89b01..05a6ed249 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -422,7 +422,9 @@ namespace Slang
return result;
}
- Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr)
+ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
+ LookupResultItem const& item,
+ Expr* originalExpr)
{
// If the only result from lookup is an entry in an interface decl, it could be that
// the user is leaving out an explicit definition for the requirement and depending on
@@ -521,13 +523,16 @@ namespace Slang
conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
conformanceDecl->parentDecl = structDecl;
structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
synthesizedDecl = structDecl;
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
typeDef->nameAndLoc.name = getName("Differential");
- auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
- typeDef->type.type = DeclRefType::create(m_astBuilder, declRef);
typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
structDecl->members.add(typeDef);
}
break;
@@ -545,8 +550,9 @@ namespace Slang
auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
addModifier(synthesizedDecl, toBeSynthesized);
+ auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
- makeDeclRef(synthesizedDecl),
+ synthDeclMemberRef,
nullptr,
originalExpr ? originalExpr->loc : SourceLoc(),
originalExpr);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index bbf6885a8..86136a010 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -459,6 +459,9 @@ Result linkAndOptimizeIR(
break;
}
+ if (sink->getErrorCount() != 0)
+ return SLANG_FAIL;
+
// If we have a target that is GPU like we use the string hashing mechanism
// but for that to work we need to inline such that calls (or returns) of strings
// boil down into getStringHash(stringLiteral)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index bd32a1896..026b8b110 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -552,8 +552,6 @@ INST(TargetSwitch, targetSwitch, 1, 0)
// A generic asm inst has an return semantics that terminates the control flow.
INST(GenericAsm, GenericAsm, 1, 0)
-INST(RequirePrelude, RequirePrelude, 1, 0)
-
INST(discard, discard, 0, 0)
/* IRUnreachable */
@@ -563,6 +561,8 @@ INST_RANGE(Unreachable, MissingReturn, Unreachable)
INST_RANGE(TerminatorInst, Return, Unreachable)
+INST(RequirePrelude, RequirePrelude, 1, 0)
+
// TODO: We should consider splitting the basic arithmetic/comparison
// ops into cases for signed integers, unsigned integers, and floating-point
// values, to better match downstream targets that want to treat them
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 070f989b5..c04450b82 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3040,7 +3040,7 @@ struct IRSPIRVAsm : IRInst
}
};
-struct IRGenericAsm : IRInst
+struct IRGenericAsm : IRTerminatorInst
{
IR_LEAF_ISA(GenericAsm)
UnownedStringSlice getAsm() { return as<IRStringLit>(getOperand(0))->getStringSlice(); }
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index b5af2d974..6970942c9 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -472,15 +472,9 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink)
for (auto inst : module->getGlobalInsts())
{
if (auto genFunc = as<IRGeneric>(inst))
- {
- if (auto func = as<IRGlobalValueWithCode>(findGenericReturnVal(genFunc)))
- {
- bool result = unrollLoopsInFunc(module, func, sink);
- if (!result)
- return false;
- }
- }
- else if (auto func = as<IRGlobalValueWithCode>(inst))
+ continue;
+
+ if (auto func = as<IRGlobalValueWithCode>(inst))
{
bool result = unrollLoopsInFunc(module, func, sink);
if (!result)
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 41665ddf7..3a7e8b9fb 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -177,9 +177,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
List<IRInst*> resultElements;
auto elementType = arrayType->getElementType();
+ auto tupleElementType = translateToTupleType(builder, elementType);
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
@@ -346,7 +347,7 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
List<IRType*> fieldTypes;
for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType(), func));
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
}
auto hostStructType = builder->createStructType();
@@ -358,6 +359,13 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
return hostStructType;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink);
+ if (!elementType)
+ return nullptr;
+ return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount());
+ }
default:
break;
}
@@ -422,13 +430,36 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
}
+ case kIROp_ArrayType:
+ {
+ auto cudaArrayType = cast<IRArrayType>(cudaType);
+ auto hostArrayType = cast<IRArrayType>(hostType);
+
+ List<IRInst*> resultElements;
+ for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++)
+ {
+ auto cudaElementType = cudaArrayType->getElementType();
+ auto hostElementType = hostArrayType->getElementType();
+ auto castedElement = castHostToCUDAType(
+ builder,
+ hostElementType,
+ cudaElementType,
+ builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)));
+
+ SLANG_RELEASE_ASSERT(castedElement);
+ resultElements.add(castedElement);
+ }
+
+ return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
default:
break;
}
- // If translateToHostType worked correctly, we shouldn't get here.
- SLANG_UNREACHABLE("unhandled type");
+ // If translateToHostType worked correctly, there should be no unhandled cases here.
+ // However, we won't diagnose here since its already diagnosed in translateToHostType()
+ return nullptr;
}
void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc)
@@ -553,6 +584,12 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
+
+ if (!type || sink->getErrorCount() > 0)
+ {
+ return nullptr;
+ }
+
auto hostParam = builder->emitParam(type);
// Add a namehint to the param by appending the suffix "_host".
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
@@ -600,6 +637,38 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
}
return;
}
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ IRBuilder builder(arrayType->getModule());
+ if (!arrayType->findDecoration<IRPyExportDecoration>())
+ builder.addPyExportDecoration(arrayType, UnownedStringSlice("Array"));
+
+ markTypeForPyExport(arrayType->getElementType(), sink);
+ return;
+ }
+}
+
+String tryGetExportTypeName(IRBuilder* builder, IRType* type)
+{
+ if (auto structType = as<IRStructType>(type))
+ {
+ if (auto pyExportDecoration = type->findDecoration<IRPyExportDecoration>())
+ return String(pyExportDecoration->getExportName());
+ else
+ return String("");
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ StringBuilder nameBuilder;
+ nameBuilder << "Array_";
+ nameBuilder << tryGetExportTypeName(builder, arrayType->getElementType());
+ nameBuilder << "_";
+ nameBuilder << cast<IRIntLit>(arrayType->getElementCount())->getValue();
+
+ return nameBuilder.produceString();
+ }
+ else
+ return String();
}
void generateReflectionForType(IRType* type, DiagnosticSink* sink)
@@ -609,7 +678,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// The list will contain the names of all the fields of the type.
//
- // TODO: Fix this to avoid emitting the same type reflection multiple times.
if (!type->findDecoration<IRPyExportDecoration>())
return;
@@ -635,20 +703,32 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
else
fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- if (!field->getFieldType()->findDecoration<IRPyExportDecoration>())
- {
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- continue;
- }
-
auto fieldType = field->getFieldType();
+ auto exportName = tryGetExportTypeName(&builder, fieldType);
- fieldTypeNames.add(
- builder.emitGetNativeString(
- builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName())));
+ if (exportName.getLength() > 0)
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice())));
+ else
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
}
break;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = as<IRArrayType>(type)->getElementType();
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type"))));
+ fieldTypeNames.add(
+ builder.emitGetNativeString(
+ builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice())));
+
+ auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount());
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size"))));
+
+ StringBuilder elementCountStr;
+ elementCountStr << elementCount->getValue();
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice())));
+ break;
+ }
default:
break;
}
@@ -676,7 +756,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// Set function name.
StringBuilder reflFuncExportName;
- reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName();
+ reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice();
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
diff --git a/source/slang/slang-ir-synthesize-active-mask.cpp b/source/slang/slang-ir-synthesize-active-mask.cpp
index 75246d553..60e13b418 100644
--- a/source/slang/slang-ir-synthesize-active-mask.cpp
+++ b/source/slang/slang-ir-synthesize-active-mask.cpp
@@ -1855,37 +1855,40 @@ struct SynthesizeActiveMaskForFunctionContext
}
else if( toBlock->getPredecessors().getCount() > 1 )
{
- // If the target block is one with multiple
- // predecessors, such that it will have an
- // added block parameter (phi node) to select
- // the corect mask value, then we need to
- // pass along the mask value to use as an
- // additional argument on the unconditional branch.
- //
- // If the old unconditional branch was:
- //
- // <op>(arg0, arg1, arg2, ...);
- //
- // Then our new branch will be:
- //
- // <op>(arg0, arg1, arg2, ..., toActiveMask);
- //
- List<IRInst*> newOperands;
- UInt oldOperandCount = terminator->getOperandCount();
- for( UInt i = 0; i < oldOperandCount; ++i )
+ if (doesBlockNeedActiveMask(toBlock))
{
- newOperands.add(terminator->getOperand(i));
- }
- newOperands.add(toActiveMask);
+ // If the target block is one with multiple
+ // predecessors, such that it will have an
+ // added block parameter (phi node) to select
+ // the corect mask value, then we need to
+ // pass along the mask value to use as an
+ // additional argument on the unconditional branch.
+ //
+ // If the old unconditional branch was:
+ //
+ // <op>(arg0, arg1, arg2, ...);
+ //
+ // Then our new branch will be:
+ //
+ // <op>(arg0, arg1, arg2, ..., toActiveMask);
+ //
+ List<IRInst*> newOperands;
+ UInt oldOperandCount = terminator->getOperandCount();
+ for( UInt i = 0; i < oldOperandCount; ++i )
+ {
+ newOperands.add(terminator->getOperand(i));
+ }
+ newOperands.add(toActiveMask);
- IRInst* newTerminator = builder.emitIntrinsicInst(
- terminator->getFullType(),
- terminator->getOp(),
- newOperands.getCount(),
- newOperands.getBuffer());
+ IRInst* newTerminator = builder.emitIntrinsicInst(
+ terminator->getFullType(),
+ terminator->getOp(),
+ newOperands.getCount(),
+ newOperands.getBuffer());
- terminator->replaceUsesWith(newTerminator);
- terminator->removeAndDeallocate();
+ terminator->replaceUsesWith(newTerminator);
+ terminator->removeAndDeallocate();
+ }
}
else
{
diff --git a/tests/autodiff/generic-differential-synthesis.slang b/tests/autodiff/generic-differential-synthesis.slang
new file mode 100644
index 000000000..8c858b9b3
--- /dev/null
+++ b/tests/autodiff/generic-differential-synthesis.slang
@@ -0,0 +1,35 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__generic<let C : int>
+struct Foo : IDifferentiable
+{
+ float x[C];
+};
+
+[Differentiable]
+Foo<3> getFoo(float x)
+{
+ return { { x, x, x } };
+}
+
+[Differentiable]
+float foobar(float x)
+{
+ int i = 3 * int(floor(x));
+ Foo<3> foo = getFoo(x);
+ return foo.x[i] * foo.x[i];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ float a = 0.5;
+ var d = fwd_diff(foobar)(diffPair(a, 1.0)).d;
+ outputBuffer[0] = d;
+ }
+}
diff --git a/tests/autodiff/generic-differential-synthesis.slang.expected.txt b/tests/autodiff/generic-differential-synthesis.slang.expected.txt
new file mode 100644
index 000000000..97de29f1f
--- /dev/null
+++ b/tests/autodiff/generic-differential-synthesis.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+1.000000
+0.000000
+0.000000
+0.000000