Ryujinx/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs
riperiperi 6096dd75b2 Handle the case where byte optionValues are sent to BSD
Some games and the Mario Odyssey Multiplayer mod do this.

The SMO multiplayer mod also needs you to revert #3394 as it uses a blocking socket to receive (otherwise it hangs), and it doesn't seem to like being forced as non-blocking.
2022-06-18 22:08:47 +01:00

338 lines
10 KiB
C#

using Ryujinx.Common.Logging;
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
class ManagedSocket : ISocket
{
/// <summary>
/// Reference counter for this socket. Both constructors initialise it to 1;
/// it is only read or modified by code outside this class.
/// </summary>
public int Refcount { get; set; }

/// <summary>Address family reported by the wrapped <see cref="Socket"/>.</summary>
public AddressFamily AddressFamily
{
    get { return Socket.AddressFamily; }
}

/// <summary>Socket type reported by the wrapped <see cref="Socket"/>.</summary>
public SocketType SocketType
{
    get { return Socket.SocketType; }
}

/// <summary>Protocol type reported by the wrapped <see cref="Socket"/>.</summary>
public ProtocolType ProtocolType
{
    get { return Socket.ProtocolType; }
}

/// <summary>Gets or sets the blocking mode of the wrapped <see cref="Socket"/>.</summary>
public bool Blocking
{
    get { return Socket.Blocking; }
    set { Socket.Blocking = value; }
}

/// <summary>Native handle of the wrapped <see cref="Socket"/>.</summary>
public IntPtr Handle
{
    get { return Socket.Handle; }
}

/// <summary>
/// Remote endpoint of the wrapped socket, or <c>null</c> when it is not an <see cref="IPEndPoint"/>.
/// </summary>
public IPEndPoint RemoteEndPoint
{
    get { return Socket.RemoteEndPoint as IPEndPoint; }
}

/// <summary>
/// Local endpoint of the wrapped socket, or <c>null</c> when it is not an <see cref="IPEndPoint"/>.
/// </summary>
public IPEndPoint LocalEndPoint
{
    get { return Socket.LocalEndPoint as IPEndPoint; }
}

/// <summary>The managed socket every member of this class forwards to.</summary>
public Socket Socket { get; }
/// <summary>
/// Creates a new underlying <see cref="System.Net.Sockets.Socket"/> with the requested
/// parameters and wraps it, starting with a reference count of 1.
/// </summary>
/// <param name="addressFamily">Address family of the socket to create.</param>
/// <param name="socketType">Type of the socket to create.</param>
/// <param name="protocolType">Protocol of the socket to create.</param>
public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
    : this(new Socket(addressFamily, socketType, protocolType))
{
}
/// <summary>
/// Wraps an already existing socket (used for connections returned by <see cref="Accept"/>),
/// starting with a reference count of 1.
/// </summary>
/// <param name="socket">The socket to wrap.</param>
private ManagedSocket(Socket socket)
{
    Refcount = 1;
    Socket = socket;
}
/// <summary>
/// Translates guest BSD send/receive flags into their <see cref="SocketFlags"/> counterparts.
/// </summary>
/// <remarks>
/// <see cref="BsdSocketFlags.DontWait"/> is accepted but produces no <see cref="SocketFlags"/> bit;
/// any other flag without a mapping is reported through a warning and otherwise ignored.
/// </remarks>
/// <param name="bsdSocketFlags">Flags supplied by the guest.</param>
/// <returns>The combination of <see cref="SocketFlags"/> matching the recognised input flags.</returns>
private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
{
    // Every flag this method knows about, whether or not it maps to a SocketFlags bit.
    const BsdSocketFlags RecognisedFlags = BsdSocketFlags.Oob |
                                           BsdSocketFlags.Peek |
                                           BsdSocketFlags.DontRoute |
                                           BsdSocketFlags.DontWait |
                                           BsdSocketFlags.Trunc |
                                           BsdSocketFlags.CTrunc;

    SocketFlags converted = SocketFlags.None;

    converted |= bsdSocketFlags.HasFlag(BsdSocketFlags.Oob) ? SocketFlags.OutOfBand : SocketFlags.None;
    converted |= bsdSocketFlags.HasFlag(BsdSocketFlags.Peek) ? SocketFlags.Peek : SocketFlags.None;
    converted |= bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute) ? SocketFlags.DontRoute : SocketFlags.None;
    converted |= bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc) ? SocketFlags.Truncated : SocketFlags.None;
    converted |= bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc) ? SocketFlags.ControlDataTruncated : SocketFlags.None;

    // Whatever is left after removing the recognised bits has no handling at all.
    BsdSocketFlags unsupported = bsdSocketFlags & ~RecognisedFlags;

    if (unsupported != BsdSocketFlags.None)
    {
        Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {unsupported}");
    }

    return converted;
}
/// <summary>
/// Accepts a pending connection on the wrapped socket and wraps it in a new <see cref="ManagedSocket"/>.
/// </summary>
/// <param name="newSocket">The accepted connection, or <c>null</c> when a socket error occurred.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Accept(out ISocket newSocket)
{
    newSocket = null;

    try
    {
        Socket accepted = Socket.Accept();

        newSocket = new ManagedSocket(accepted);
    }
    catch (SocketException socketError)
    {
        return WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return LinuxError.SUCCESS;
}
/// <summary>Binds the wrapped socket to a local endpoint.</summary>
/// <param name="localEndPoint">Local endpoint to bind to.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Bind(IPEndPoint localEndPoint)
{
    LinuxError outcome = LinuxError.SUCCESS;

    try
    {
        Socket.Bind(localEndPoint);
    }
    catch (SocketException socketError)
    {
        outcome = WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return outcome;
}
/// <summary>Closes the wrapped socket.</summary>
public void Close() => Socket.Close();
/// <summary>Connects the wrapped socket to a remote endpoint.</summary>
/// <param name="remoteEndPoint">Endpoint to connect to.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success; <see cref="LinuxError.EINPROGRESS"/> when the socket is
/// non-blocking and the connect reported <see cref="WsaError.WSAEWOULDBLOCK"/>; otherwise the socket
/// error code converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Connect(IPEndPoint remoteEndPoint)
{
    try
    {
        Socket.Connect(remoteEndPoint);
    }
    catch (SocketException socketError)
    {
        // A "would block" on a non-blocking socket is reported to the guest as EINPROGRESS
        // rather than as a plain would-block error.
        bool connectStillPending = !Blocking && socketError.ErrorCode == (int)WsaError.WSAEWOULDBLOCK;

        return connectStillPending
            ? LinuxError.EINPROGRESS
            : WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return LinuxError.SUCCESS;
}
/// <summary>
/// Disconnects the wrapped socket, passing <c>reuseSocket: true</c> so it may be connected again.
/// </summary>
public void Disconnect() => Socket.Disconnect(reuseSocket: true);
/// <summary>Closes the wrapped socket and then disposes it.</summary>
public void Dispose()
{
    Socket wrapped = Socket;

    wrapped.Close();
    wrapped.Dispose();
}
/// <summary>Puts the wrapped socket into the listening state.</summary>
/// <param name="backlog">Backlog value forwarded to <see cref="System.Net.Sockets.Socket.Listen(int)"/>.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Listen(int backlog)
{
    LinuxError outcome = LinuxError.SUCCESS;

    try
    {
        Socket.Listen(backlog);
    }
    catch (SocketException socketError)
    {
        outcome = WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return outcome;
}
/// <summary>Polls the wrapped socket for the given condition.</summary>
/// <param name="microSeconds">Wait time in microseconds, forwarded unchanged.</param>
/// <param name="mode">Condition to poll for.</param>
/// <returns>The result of <see cref="System.Net.Sockets.Socket.Poll(int, SelectMode)"/>.</returns>
public bool Poll(int microSeconds, SelectMode mode) => Socket.Poll(microSeconds, mode);
/// <summary>Shuts down sending and/or receiving on the wrapped socket.</summary>
/// <param name="how">
/// Direction(s) to shut down; the value is cast directly to <see cref="SocketShutdown"/>.
/// </param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Shutdown(BsdSocketShutdownFlags how)
{
    LinuxError outcome = LinuxError.SUCCESS;

    try
    {
        SocketShutdown direction = (SocketShutdown)how;

        Socket.Shutdown(direction);
    }
    catch (SocketException socketError)
    {
        outcome = WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return outcome;
}
/// <summary>Receives data from the wrapped socket into <paramref name="buffer"/>.</summary>
/// <remarks>
/// When <see cref="BsdSocketFlags.DontWait"/> is requested on a blocking socket, the socket is switched
/// to non-blocking mode for the duration of this single call and switched back afterwards, the same
/// temporary switch <see cref="ReceiveFrom"/> performs. Without it the flag would be silently dropped
/// (it has no <see cref="SocketFlags"/> equivalent) and the call would block.
/// </remarks>
/// <param name="receiveSize">Number of bytes received, or -1 when a socket error occurred.</param>
/// <param name="buffer">Destination buffer.</param>
/// <param name="flags">BSD receive flags supplied by the guest.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Receive(out int receiveSize, Span<byte> buffer, BsdSocketFlags flags)
{
    bool shouldBlockAfterOperation = false;

    try
    {
        if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
        {
            Blocking = false;
            shouldBlockAfterOperation = true;
        }

        receiveSize = Socket.Receive(buffer, ConvertBsdSocketFlags(flags));

        return LinuxError.SUCCESS;
    }
    catch (SocketException exception)
    {
        receiveSize = -1;

        return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
    }
    finally
    {
        // Restore the caller-visible blocking mode no matter how the receive ended.
        if (shouldBlockAfterOperation)
        {
            Blocking = true;
        }
    }
}
/// <summary>
/// Receives a datagram from the wrapped socket and reports the endpoint it came from.
/// </summary>
/// <remarks>
/// When <see cref="BsdSocketFlags.DontWait"/> is requested on a blocking socket, the socket is switched
/// to non-blocking mode for the duration of this single call. The original mode is restored in a
/// <c>finally</c> block so that an exception other than <see cref="SocketException"/> (for example an
/// out-of-range <paramref name="size"/>) cannot leave the socket stuck in non-blocking mode.
/// </remarks>
/// <param name="receiveSize">Number of bytes received, or -1 when a socket error occurred.</param>
/// <param name="buffer">Destination buffer; only its first <paramref name="size"/> bytes are used.</param>
/// <param name="size">Maximum number of bytes to receive.</param>
/// <param name="flags">BSD receive flags supplied by the guest.</param>
/// <param name="remoteEndPoint">
/// Sender of the received data; stays at <see cref="IPAddress.Any"/> port 0 when a socket error occurred.
/// </param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError ReceiveFrom(out int receiveSize, Span<byte> buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint)
{
    remoteEndPoint = new IPEndPoint(IPAddress.Any, 0);

    bool shouldBlockAfterOperation = false;

    try
    {
        EndPoint temp = new IPEndPoint(IPAddress.Any, 0);

        if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
        {
            Blocking = false;
            shouldBlockAfterOperation = true;
        }

        receiveSize = Socket.ReceiveFrom(buffer[..size], ConvertBsdSocketFlags(flags), ref temp);

        remoteEndPoint = (IPEndPoint)temp;

        return LinuxError.SUCCESS;
    }
    catch (SocketException exception)
    {
        receiveSize = -1;

        return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
    }
    finally
    {
        // Restore the caller-visible blocking mode no matter how the receive ended.
        if (shouldBlockAfterOperation)
        {
            Blocking = true;
        }
    }
}
/// <summary>Sends the contents of <paramref name="buffer"/> through the wrapped socket.</summary>
/// <param name="sendSize">Number of bytes sent, or -1 when a socket error occurred.</param>
/// <param name="buffer">Data to send.</param>
/// <param name="flags">BSD send flags supplied by the guest.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError Send(out int sendSize, ReadOnlySpan<byte> buffer, BsdSocketFlags flags)
{
    try
    {
        SocketFlags managedFlags = ConvertBsdSocketFlags(flags);

        sendSize = Socket.Send(buffer, managedFlags);
    }
    catch (SocketException socketError)
    {
        sendSize = -1;

        return WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return LinuxError.SUCCESS;
}
/// <summary>
/// Sends the first <paramref name="size"/> bytes of <paramref name="buffer"/> to a specific endpoint.
/// </summary>
/// <param name="sendSize">Number of bytes sent, or -1 when a socket error occurred.</param>
/// <param name="buffer">Source buffer.</param>
/// <param name="size">Number of bytes of <paramref name="buffer"/> to send.</param>
/// <param name="flags">BSD send flags supplied by the guest.</param>
/// <param name="remoteEndPoint">Destination endpoint.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success, otherwise the socket error code
/// converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint)
{
    try
    {
        ReadOnlySpan<byte> payload = buffer.Slice(0, size);
        SocketFlags managedFlags = ConvertBsdSocketFlags(flags);

        sendSize = Socket.SendTo(payload, managedFlags, remoteEndPoint);
    }
    catch (SocketException socketError)
    {
        sendSize = -1;

        return WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return LinuxError.SUCCESS;
}
/// <summary>Reads a socket option from the wrapped socket into <paramref name="optionValue"/>.</summary>
/// <param name="option">BSD option requested by the guest.</param>
/// <param name="level">Option level.</param>
/// <param name="optionValue">Destination for the raw option bytes; its length determines how many bytes are requested.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success; <see cref="LinuxError.EOPNOTSUPP"/> when
/// <see cref="WinSockHelper.TryConvertSocketOption"/> rejects the option; otherwise the socket
/// error code converted by <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue)
{
    try
    {
        bool isSupported = WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName);

        if (!isSupported)
        {
            Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}");

            return LinuxError.EOPNOTSUPP;
        }

        // The Socket API used here takes a byte[], so go through a scratch array of the same size.
        byte[] scratch = new byte[optionValue.Length];

        Socket.GetSocketOption(level, optionName, scratch);

        new ReadOnlySpan<byte>(scratch).CopyTo(optionValue);
    }
    catch (SocketException socketError)
    {
        return WinSockHelper.ConvertError((WsaError)socketError.ErrorCode);
    }

    return LinuxError.SUCCESS;
}
/// <summary>Applies a socket option supplied by the guest to the wrapped socket.</summary>
/// <remarks>
/// The value is read as a 32-bit integer when at least four bytes are supplied and as a single byte
/// otherwise. <see cref="BsdSocketOption.SoLinger"/> is read as two consecutive 32-bit integers
/// (enabled flag, then linger time). Buffers too short for the requested option are rejected with
/// <see cref="LinuxError.EINVAL"/> instead of letting the span reads throw an exception that the
/// <see cref="SocketException"/> handler below would not catch.
/// </remarks>
/// <param name="option">BSD option requested by the guest.</param>
/// <param name="level">Option level.</param>
/// <param name="optionValue">Raw option bytes supplied by the guest.</param>
/// <returns>
/// <see cref="LinuxError.SUCCESS"/> on success; <see cref="LinuxError.EOPNOTSUPP"/> when
/// <see cref="WinSockHelper.TryConvertSocketOption"/> rejects the option; <see cref="LinuxError.EINVAL"/>
/// when <paramref name="optionValue"/> is too short; otherwise the socket error code converted by
/// <see cref="WinSockHelper.ConvertError"/>.
/// </returns>
public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue)
{
    try
    {
        if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
        {
            Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}");

            return LinuxError.EOPNOTSUPP;
        }

        // Nothing to read at all: reading even a single byte would throw.
        if (optionValue.IsEmpty)
        {
            return LinuxError.EINVAL;
        }

        int value = optionValue.Length >= sizeof(int) ? MemoryMarshal.Read<int>(optionValue) : optionValue[0];

        if (option == BsdSocketOption.SoLinger)
        {
            // The linger value is made of two ints; a shorter buffer cannot hold it.
            if (optionValue.Length < 2 * sizeof(int))
            {
                return LinuxError.EINVAL;
            }

            int value2 = MemoryMarshal.Read<int>(optionValue[sizeof(int)..]);

            Socket.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2));
        }
        else
        {
            Socket.SetSocketOption(level, optionName, value);
        }

        return LinuxError.SUCCESS;
    }
    catch (SocketException exception)
    {
        return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
    }
}
/// <summary>
/// Reads from the socket; equivalent to <see cref="Receive"/> with <see cref="BsdSocketFlags.None"/>.
/// </summary>
/// <param name="readSize">Number of bytes read, as reported by <see cref="Receive"/>.</param>
/// <param name="buffer">Destination buffer.</param>
/// <returns>The result of <see cref="Receive"/>.</returns>
public LinuxError Read(out int readSize, Span<byte> buffer)
    => Receive(out readSize, buffer, BsdSocketFlags.None);
/// <summary>
/// Writes to the socket; equivalent to <see cref="Send"/> with <see cref="BsdSocketFlags.None"/>.
/// </summary>
/// <param name="writeSize">Number of bytes written, as reported by <see cref="Send"/>.</param>
/// <param name="buffer">Data to write.</param>
/// <returns>The result of <see cref="Send"/>.</returns>
public LinuxError Write(out int writeSize, ReadOnlySpan<byte> buffer)
    => Send(out writeSize, buffer, BsdSocketFlags.None);
}
}