Protect against stack overflow caused by deep recursive calls

This commit is contained in:
gdk 2022-06-29 15:46:45 -03:00 committed by Mary
parent 2cdc82cb91
commit 2291819c41
7 changed files with 72 additions and 10 deletions

View file

@ -19,7 +19,7 @@ namespace ARMeilleure.Instructions
context.LoadFromContext(); context.LoadFromContext();
context.Return(Const(op.Address)); InstEmitFlowHelper.EmitReturn(context, Const(op.Address));
} }
public static void Svc(ArmEmitterContext context) public static void Svc(ArmEmitterContext context)
@ -49,7 +49,7 @@ namespace ARMeilleure.Instructions
context.LoadFromContext(); context.LoadFromContext();
context.Return(Const(op.Address)); InstEmitFlowHelper.EmitReturn(context, Const(op.Address));
} }
} }
} }

View file

@ -33,7 +33,7 @@ namespace ARMeilleure.Instructions
context.LoadFromContext(); context.LoadFromContext();
context.Return(Const(context.CurrOp.Address)); InstEmitFlowHelper.EmitReturn(context, Const(context.CurrOp.Address));
} }
} }
} }

View file

@ -66,7 +66,7 @@ namespace ARMeilleure.Instructions
{ {
OpCodeBReg op = (OpCodeBReg)context.CurrOp; OpCodeBReg op = (OpCodeBReg)context.CurrOp;
context.Return(GetIntOrZR(context, op.Rn)); EmitReturn(context, GetIntOrZR(context, op.Rn));
} }
public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true); public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true);

View file

@ -12,6 +12,10 @@ namespace ARMeilleure.Instructions
{ {
static class InstEmitFlowHelper static class InstEmitFlowHelper
{ {
// How many calls we can have in our call stack before we give up and return to the dispatcher.
// This prevents stack overflows caused by deep recursive calls.
private const int MaxCallDepth = 200;
public static void EmitCondBranch(ArmEmitterContext context, Operand target, Condition cond) public static void EmitCondBranch(ArmEmitterContext context, Operand target, Condition cond)
{ {
if (cond != Condition.Al) if (cond != Condition.Al)
@ -163,6 +167,19 @@ namespace ARMeilleure.Instructions
{ {
if (isReturn) if (isReturn)
{ {
EmitReturn(context, target);
}
else
{
EmitTableBranch(context, target, isJump: true);
}
}
public static void EmitReturn(ArmEmitterContext context, Operand target)
{
Operand nativeContext = context.LoadArgument(OperandType.I64, 0);
DecreaseCallDepth(context, nativeContext);
if (target.Type == OperandType.I32) if (target.Type == OperandType.I32)
{ {
target = context.ZeroExtend32(OperandType.I64, target); target = context.ZeroExtend32(OperandType.I64, target);
@ -170,11 +187,6 @@ namespace ARMeilleure.Instructions
context.Return(target); context.Return(target);
} }
else
{
EmitTableBranch(context, target, isJump: true);
}
}
private static void EmitTableBranch(ArmEmitterContext context, Operand guestAddress, bool isJump) private static void EmitTableBranch(ArmEmitterContext context, Operand guestAddress, bool isJump)
{ {
@ -218,6 +230,8 @@ namespace ARMeilleure.Instructions
{ {
OpCode op = context.CurrOp; OpCode op = context.CurrOp;
EmitCallDepthCheckAndIncrement(context, nativeContext, guestAddress);
Operand returnAddress = context.Call(hostAddress, OperandType.I64, nativeContext); Operand returnAddress = context.Call(hostAddress, OperandType.I64, nativeContext);
context.LoadFromContext(); context.LoadFromContext();
@ -233,8 +247,41 @@ namespace ARMeilleure.Instructions
Operand lblContinue = context.GetLabel(nextAddr.Value); Operand lblContinue = context.GetLabel(nextAddr.Value);
context.BranchIf(lblContinue, returnAddress, nextAddr, Comparison.Equal, BasicBlockFrequency.Cold); context.BranchIf(lblContinue, returnAddress, nextAddr, Comparison.Equal, BasicBlockFrequency.Cold);
DecreaseCallDepth(context, nativeContext);
context.Return(returnAddress); context.Return(returnAddress);
} }
} }
private static void EmitCallDepthCheckAndIncrement(EmitterContext context, Operand nativeContext, Operand guestAddress)
{
if (!Optimizations.EnableDeepCallRecursionProtection)
{
return;
}
Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset()));
Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr);
Operand lblDoCall = Label();
context.BranchIf(lblDoCall, currentCallDepth, Const(MaxCallDepth), Comparison.LessUI);
context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1)));
context.Return(guestAddress);
context.MarkLabel(lblDoCall);
context.Store(callDepthAddr, context.Add(currentCallDepth, Const(1)));
}
private static void DecreaseCallDepth(EmitterContext context, Operand nativeContext)
{
if (!Optimizations.EnableDeepCallRecursionProtection)
{
return;
}
Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset()));
Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr);
context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1)));
}
} }
} }

View file

@ -9,6 +9,7 @@ namespace ARMeilleure
public static bool AllowLcqInFunctionTable { get; set; } = true; public static bool AllowLcqInFunctionTable { get; set; } = true;
public static bool UseUnmanagedDispatchLoop { get; set; } = true; public static bool UseUnmanagedDispatchLoop { get; set; } = true;
public static bool EnableDeepCallRecursionProtection { get; set; } = true;
public static bool UseAdvSimdIfAvailable { get; set; } = true; public static bool UseAdvSimdIfAvailable { get; set; } = true;
public static bool UseArm64AesIfAvailable { get; set; } = true; public static bool UseArm64AesIfAvailable { get; set; } = true;

View file

@ -21,6 +21,7 @@ namespace ARMeilleure.State
public ulong ExclusiveValueLow; public ulong ExclusiveValueLow;
public ulong ExclusiveValueHigh; public ulong ExclusiveValueHigh;
public int Running; public int Running;
public int CallDepth;
} }
private static NativeCtxStorage _dummyStorage = new(); private static NativeCtxStorage _dummyStorage = new();
@ -257,6 +258,11 @@ namespace ARMeilleure.State
return StorageOffset(ref _dummyStorage, ref _dummyStorage.Running); return StorageOffset(ref _dummyStorage, ref _dummyStorage.Running);
} }
public static int GetCallDepthOffset()
{
return StorageOffset(ref _dummyStorage, ref _dummyStorage.CallDepth);
}
private static int StorageOffset<T>(ref NativeCtxStorage storage, ref T target) private static int StorageOffset<T>(ref NativeCtxStorage storage, ref T target)
{ {
return (int)Unsafe.ByteOffset(ref Unsafe.As<NativeCtxStorage, T>(ref storage), ref target); return (int)Unsafe.ByteOffset(ref Unsafe.As<NativeCtxStorage, T>(ref storage), ref target);

View file

@ -262,10 +262,18 @@ namespace ARMeilleure.Translation
Operand runningAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetRunningOffset())); Operand runningAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetRunningOffset()));
Operand dispatchAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetDispatchAddressOffset())); Operand dispatchAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetDispatchAddressOffset()));
Operand callDepthAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset()));
EmitSyncFpContext(context, nativeContext, true); EmitSyncFpContext(context, nativeContext, true);
context.MarkLabel(beginLbl); context.MarkLabel(beginLbl);
if (Optimizations.EnableDeepCallRecursionProtection)
{
// Reset the call depth counter, since this is our first guest function call.
context.Store(callDepthAddress, Const(1));
}
context.Store(dispatchAddress, guestAddress); context.Store(dispatchAddress, guestAddress);
context.Copy(guestAddress, context.Call(Const((ulong)DispatchStub), OperandType.I64, nativeContext)); context.Copy(guestAddress, context.Call(Const((ulong)DispatchStub), OperandType.I64, nativeContext));
context.BranchIfFalse(endLbl, guestAddress); context.BranchIfFalse(endLbl, guestAddress);