From 2291819c41ce3bcdb5cc9074b1377ea98c18466b Mon Sep 17 00:00:00 2001 From: gdk Date: Wed, 29 Jun 2022 15:46:45 -0300 Subject: [PATCH] Protect against stack overflow caused by deep recursive calls --- .../Instructions/InstEmitException.cs | 4 +- .../Instructions/InstEmitException32.cs | 2 +- src/ARMeilleure/Instructions/InstEmitFlow.cs | 2 +- .../Instructions/InstEmitFlowHelper.cs | 59 +++++++++++++++++-- src/ARMeilleure/Optimizations.cs | 1 + src/ARMeilleure/State/NativeContext.cs | 6 ++ .../Translation/TranslatorStubs.cs | 8 +++ 7 files changed, 72 insertions(+), 10 deletions(-) diff --git a/src/ARMeilleure/Instructions/InstEmitException.cs b/src/ARMeilleure/Instructions/InstEmitException.cs index d30fb2fbd..a91716c64 100644 --- a/src/ARMeilleure/Instructions/InstEmitException.cs +++ b/src/ARMeilleure/Instructions/InstEmitException.cs @@ -19,7 +19,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } public static void Svc(ArmEmitterContext context) @@ -49,7 +49,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitException32.cs b/src/ARMeilleure/Instructions/InstEmitException32.cs index ec0c32bf9..d9a01cba6 100644 --- a/src/ARMeilleure/Instructions/InstEmitException32.cs +++ b/src/ARMeilleure/Instructions/InstEmitException32.cs @@ -33,7 +33,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(context.CurrOp.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(context.CurrOp.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitFlow.cs b/src/ARMeilleure/Instructions/InstEmitFlow.cs index a986bf66f..cb214d3d5 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlow.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlow.cs @@ -66,7 +66,7 @@ namespace ARMeilleure.Instructions { OpCodeBReg op = (OpCodeBReg)context.CurrOp; - context.Return(GetIntOrZR(context, op.Rn)); + EmitReturn(context, GetIntOrZR(context, op.Rn)); } public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true); diff --git a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs index 2009bafda..f4f7fe4cb 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs @@ -12,6 +12,10 @@ namespace ARMeilleure.Instructions { static class InstEmitFlowHelper { + // How many calls we can have in our call stack before we give up and return to the dispatcher. + // This prevents stack overflows caused by deep recursive calls. + private const int MaxCallDepth = 200; + public static void EmitCondBranch(ArmEmitterContext context, Operand target, Condition cond) { if (cond != Condition.Al) @@ -163,12 +167,7 @@ namespace ARMeilleure.Instructions { if (isReturn) { - if (target.Type == OperandType.I32) - { - target = context.ZeroExtend32(OperandType.I64, target); - } - - context.Return(target); + EmitReturn(context, target); } else { @@ -176,6 +175,19 @@ namespace ARMeilleure.Instructions } } + public static void EmitReturn(ArmEmitterContext context, Operand target) + { + Operand nativeContext = context.LoadArgument(OperandType.I64, 0); + DecreaseCallDepth(context, nativeContext); + + if (target.Type == OperandType.I32) + { + target = context.ZeroExtend32(OperandType.I64, target); + } + + context.Return(target); + } + private static void EmitTableBranch(ArmEmitterContext context, Operand guestAddress, bool isJump) { context.StoreToContext(); @@ -218,6 +230,8 @@ namespace ARMeilleure.Instructions { OpCode op = context.CurrOp; + EmitCallDepthCheckAndIncrement(context, nativeContext, guestAddress); + Operand returnAddress = context.Call(hostAddress, OperandType.I64, nativeContext); context.LoadFromContext(); @@ -233,8 +247,41 @@ namespace ARMeilleure.Instructions Operand lblContinue = context.GetLabel(nextAddr.Value); context.BranchIf(lblContinue, returnAddress, nextAddr, Comparison.Equal, BasicBlockFrequency.Cold); + DecreaseCallDepth(context, nativeContext); + context.Return(returnAddress); } } + + private static void EmitCallDepthCheckAndIncrement(EmitterContext context, Operand nativeContext, Operand guestAddress) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + Operand lblDoCall = Label(); + + context.BranchIf(lblDoCall, currentCallDepth, Const(MaxCallDepth), Comparison.LessUI); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + context.Return(guestAddress); + + context.MarkLabel(lblDoCall); + context.Store(callDepthAddr, context.Add(currentCallDepth, Const(1))); + } + + private static void DecreaseCallDepth(EmitterContext context, Operand nativeContext) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + } } } diff --git a/src/ARMeilleure/Optimizations.cs b/src/ARMeilleure/Optimizations.cs index 8fe478e47..10a6b87e2 100644 --- a/src/ARMeilleure/Optimizations.cs +++ b/src/ARMeilleure/Optimizations.cs @@ -9,6 +9,7 @@ namespace ARMeilleure public static bool AllowLcqInFunctionTable { get; set; } = true; public static bool UseUnmanagedDispatchLoop { get; set; } = true; + public static bool EnableDeepCallRecursionProtection { get; set; } = true; public static bool UseAdvSimdIfAvailable { get; set; } = true; public static bool UseArm64AesIfAvailable { get; set; } = true; diff --git a/src/ARMeilleure/State/NativeContext.cs b/src/ARMeilleure/State/NativeContext.cs index 5403042ea..348fd11f3 100644 --- a/src/ARMeilleure/State/NativeContext.cs +++ b/src/ARMeilleure/State/NativeContext.cs @@ -21,6 +21,7 @@ namespace ARMeilleure.State public ulong ExclusiveValueLow; public ulong ExclusiveValueHigh; public int Running; + public int CallDepth; } private static NativeCtxStorage _dummyStorage = new(); @@ -257,6 +258,11 @@ namespace ARMeilleure.State return StorageOffset(ref _dummyStorage, ref _dummyStorage.Running); } + public static int GetCallDepthOffset() + { + return StorageOffset(ref _dummyStorage, ref _dummyStorage.CallDepth); + } + private static int StorageOffset(ref NativeCtxStorage storage, ref T target) { return (int)Unsafe.ByteOffset(ref Unsafe.As(ref storage), ref target); diff --git a/src/ARMeilleure/Translation/TranslatorStubs.cs b/src/ARMeilleure/Translation/TranslatorStubs.cs index eceb1b742..a37231f09 100644 --- a/src/ARMeilleure/Translation/TranslatorStubs.cs +++ b/src/ARMeilleure/Translation/TranslatorStubs.cs @@ -262,10 +262,18 @@ namespace ARMeilleure.Translation Operand runningAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetRunningOffset())); Operand dispatchAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetDispatchAddressOffset())); + Operand callDepthAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); EmitSyncFpContext(context, nativeContext, true); context.MarkLabel(beginLbl); + + if (Optimizations.EnableDeepCallRecursionProtection) + { + // Reset the call depth counter, since this is our first guest function call. + context.Store(callDepthAddress, Const(1)); + } + context.Store(dispatchAddress, guestAddress); context.Copy(guestAddress, context.Call(Const((ulong)DispatchStub), OperandType.I64, nativeContext)); context.BranchIfFalse(endLbl, guestAddress);