More Shader Gen Stuff

Mostly copied from GLSL since in terms of syntax within blocks they’re pretty similar. Likely the result will need tweaking…

Isn’t that conveniant?

“Do the simd_shuffle”

atomics

Remaining instructions

Remove removed special instructions

Getting somewhere…
This commit is contained in:
Isaac Marovitz 2023-08-03 23:21:22 -04:00 committed by Isaac Marovitz
parent 1790050a14
commit f07327166c
13 changed files with 911 additions and 51 deletions

View file

@ -8,9 +8,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
public const string Tab = " "; public const string Tab = " ";
public StructuredFunction CurrentFunction { get; set; }
public StructuredProgramInfo Info { get; } public StructuredProgramInfo Info { get; }
public ShaderConfig Config { get; } public ShaderConfig Config { get; }
public OperandManager OperandManager { get; }
private readonly StringBuilder _sb; private readonly StringBuilder _sb;
private int _level; private int _level;
@ -22,6 +27,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
Info = info; Info = info;
Config = config; Config = config;
OperandManager = new OperandManager();
_sb = new StringBuilder(); _sb = new StringBuilder();
} }

View file

@ -1,4 +1,6 @@
using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
@ -10,6 +12,46 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine("#include <simd/simd.h>"); context.AppendLine("#include <simd/simd.h>");
context.AppendLine(); context.AppendLine();
context.AppendLine("using namespace metal;"); context.AppendLine("using namespace metal;");
if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0)
{
}
}
public static void DeclareLocals(CodeGenContext context, StructuredFunction function)
{
foreach (AstOperand decl in function.Locals)
{
string name = context.OperandManager.DeclareLocal(decl);
context.AppendLine(GetVarTypeName(context, decl.VarType) + " " + name + ";");
}
}
public static string GetVarTypeName(CodeGenContext context, AggregateType type)
{
return type switch
{
AggregateType.Void => "void",
AggregateType.Bool => "bool",
AggregateType.FP32 => "float",
AggregateType.S32 => "int",
AggregateType.U32 => "uint",
AggregateType.Vector2 | AggregateType.Bool => "bool2",
AggregateType.Vector2 | AggregateType.FP32 => "float2",
AggregateType.Vector2 | AggregateType.S32 => "int2",
AggregateType.Vector2 | AggregateType.U32 => "uint2",
AggregateType.Vector3 | AggregateType.Bool => "bool3",
AggregateType.Vector3 | AggregateType.FP32 => "float3",
AggregateType.Vector3 | AggregateType.S32 => "int3",
AggregateType.Vector3 | AggregateType.U32 => "uint3",
AggregateType.Vector4 | AggregateType.Bool => "bool4",
AggregateType.Vector4 | AggregateType.FP32 => "float4",
AggregateType.Vector4 | AggregateType.S32 => "int4",
AggregateType.Vector4 | AggregateType.U32 => "uint4",
_ => throw new ArgumentException($"Invalid variable type \"{type}\"."),
};
} }
} }
} }

View file

@ -0,0 +1,15 @@
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class DefaultNames
{
public const string LocalNamePrefix = "temp";
public const string PerPatchAttributePrefix = "patchAttr";
public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr";
public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "undef";
}
}

View file

@ -0,0 +1,7 @@
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class HelperFunctionNames
{
public static string SwizzleAdd = "helperSwizzleAdd";
}
}

View file

@ -0,0 +1,4 @@
inline bool voteAllEqual(bool value)
{
return simd_all(value) || !simd_any(value);
}

View file

@ -0,0 +1,149 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGen
{
public static string GetExpression(CodeGenContext context, IAstNode node)
{
if (node is AstOperation operation)
{
return GetExpression(context, operation);
}
else if (node is AstOperand operand)
{
return context.OperandManager.GetExpression(context, operand);
}
throw new ArgumentException($"Invalid node type \"{node?.GetType().Name ?? "null"}\".");
}
private static string GetExpression(CodeGenContext context, AstOperation operation)
{
Instruction inst = operation.Inst;
InstInfo info = GetInstructionInfo(inst);
if ((info.Type & InstType.Call) != 0)
{
bool atomic = (info.Type & InstType.Atomic) != 0;
int arity = (int)(info.Type & InstType.ArityMask);
string args = string.Empty;
if (atomic)
{
// Hell
}
else
{
for (int argIndex = 0; argIndex < arity; argIndex++)
{
if (argIndex != 0)
{
args += ", ";
}
AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType);
}
}
return info.OpName + '(' + args + ')';
}
else if ((info.Type & InstType.Op) != 0)
{
string op = info.OpName;
if (inst == Instruction.Return && operation.SourcesCount != 0)
{
return $"{op} {GetSourceExpr(context, operation.GetSource(0), context.CurrentFunction.ReturnType)}";
}
int arity = (int)(info.Type & InstType.ArityMask);
string[] expr = new string[arity];
for (int index = 0; index < arity; index++)
{
IAstNode src = operation.GetSource(index);
string srcExpr = GetSourceExpr(context, src, GetSrcVarType(inst, index));
bool isLhs = arity == 2 && index == 0;
expr[index] = Enclose(srcExpr, src, inst, info, isLhs);
}
switch (arity)
{
case 0:
return op;
case 1:
return op + expr[0];
case 2:
return $"{expr[0]} {op} {expr[1]}";
case 3:
return $"{expr[0]} {op[0]} {expr[1]} {op[1]} {expr[2]}";
}
}
else if ((info.Type & InstType.Special) != 0)
{
switch (inst & Instruction.Mask)
{
case Instruction.Barrier:
return "|| BARRIER ||";
case Instruction.Call:
return "|| CALL ||";
case Instruction.FSIBegin:
return "|| FSI BEGIN ||";
case Instruction.FSIEnd:
return "|| FSI END ||";
case Instruction.FindLSB:
return "|| FIND LSB ||";
case Instruction.FindMSBS32:
return "|| FIND MSB S32 ||";
case Instruction.FindMSBU32:
return "|| FIND MSB U32 ||";
case Instruction.GroupMemoryBarrier:
return "|| FIND GROUP MEMORY BARRIER ||";
case Instruction.ImageLoad:
return "|| IMAGE LOAD ||";
case Instruction.ImageStore:
return "|| IMAGE STORE ||";
case Instruction.ImageAtomic:
return "|| IMAGE ATOMIC ||";
case Instruction.Load:
return "|| LOAD ||";
case Instruction.Lod:
return "|| LOD ||";
case Instruction.MemoryBarrier:
return "|| MEMORY BARRIER ||";
case Instruction.Store:
return "|| STORE ||";
case Instruction.TextureSample:
return "|| TEXTURE SAMPLE ||";
case Instruction.TextureSize:
return "|| TEXTURE SIZE ||";
case Instruction.VectorExtract:
return "|| VECTOR EXTRACT ||";
case Instruction.VoteAllEqual:
return "|| VOTE ALL EQUAL ||";
}
}
throw new InvalidOperationException($"Unexpected instruction type \"{info.Type}\".");
}
}
}

View file

@ -1,4 +1,6 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{ {
@ -11,42 +13,42 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
_infoTable = new InstInfo[(int)Instruction.Count]; _infoTable = new InstInfo[(int)Instruction.Count];
#pragma warning disable IDE0055 // Disable formatting #pragma warning disable IDE0055 // Disable formatting
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "add"); Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_add_explicit");
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "and"); Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_and_explicit");
Add(Instruction.AtomicCompareAndSwap, 0); Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit");
Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "max"); Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_max_explicit");
Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "min"); Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_min_explicit");
Add(Instruction.AtomicOr, InstType.AtomicBinary, "or"); Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_or_explicit");
Add(Instruction.AtomicSwap, 0); Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit");
Add(Instruction.AtomicXor, InstType.AtomicBinary, "xor"); Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_xor_explicit");
Add(Instruction.Absolute, InstType.AtomicBinary, "abs"); Add(Instruction.Absolute, InstType.AtomicBinary, "atomic_abs_explicit");
Add(Instruction.Add, InstType.OpBinaryCom, "+"); Add(Instruction.Add, InstType.OpBinaryCom, "+", 2);
Add(Instruction.Ballot, InstType.Special); Add(Instruction.Ballot, InstType.CallUnary, "simd_ballot");
Add(Instruction.Barrier, InstType.Special); Add(Instruction.Barrier, InstType.Special);
Add(Instruction.BitCount, InstType.CallUnary, "popcount"); Add(Instruction.BitCount, InstType.CallUnary, "popcount");
Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits"); Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits");
Add(Instruction.BitfieldExtractU32, InstType.CallTernary, "extract_bits"); Add(Instruction.BitfieldExtractU32, InstType.CallTernary, "extract_bits");
Add(Instruction.BitfieldInsert, InstType.CallQuaternary, "insert_bits"); Add(Instruction.BitfieldInsert, InstType.CallQuaternary, "insert_bits");
Add(Instruction.BitfieldReverse, InstType.CallUnary, "reverse_bits"); Add(Instruction.BitfieldReverse, InstType.CallUnary, "reverse_bits");
Add(Instruction.BitwiseAnd, InstType.OpBinaryCom, "&"); Add(Instruction.BitwiseAnd, InstType.OpBinaryCom, "&", 6);
Add(Instruction.BitwiseExclusiveOr, InstType.OpBinaryCom, "^"); Add(Instruction.BitwiseExclusiveOr, InstType.OpBinaryCom, "^", 7);
Add(Instruction.BitwiseNot, InstType.OpUnary, "~"); Add(Instruction.BitwiseNot, InstType.OpUnary, "~", 0);
Add(Instruction.BitwiseOr, InstType.OpBinaryCom, "|"); Add(Instruction.BitwiseOr, InstType.OpBinaryCom, "|", 8);
Add(Instruction.Call, InstType.Special); Add(Instruction.Call, InstType.Special);
Add(Instruction.Ceiling, InstType.CallUnary, "ceil"); Add(Instruction.Ceiling, InstType.CallUnary, "ceil");
Add(Instruction.Clamp, InstType.CallTernary, "clamp"); Add(Instruction.Clamp, InstType.CallTernary, "clamp");
Add(Instruction.ClampU32, InstType.CallTernary, "clamp"); Add(Instruction.ClampU32, InstType.CallTernary, "clamp");
Add(Instruction.CompareEqual, InstType.OpBinaryCom, "=="); Add(Instruction.CompareEqual, InstType.OpBinaryCom, "==", 5);
Add(Instruction.CompareGreater, InstType.OpBinary, ">"); Add(Instruction.CompareGreater, InstType.OpBinary, ">", 4);
Add(Instruction.CompareGreaterOrEqual, InstType.OpBinary, ">="); Add(Instruction.CompareGreaterOrEqual, InstType.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterOrEqualU32, InstType.OpBinary, ">="); Add(Instruction.CompareGreaterOrEqualU32, InstType.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterU32, InstType.OpBinary, ">"); Add(Instruction.CompareGreaterU32, InstType.OpBinary, ">", 4);
Add(Instruction.CompareLess, InstType.OpBinary, "<"); Add(Instruction.CompareLess, InstType.OpBinary, "<", 4);
Add(Instruction.CompareLessOrEqual, InstType.OpBinary, "<="); Add(Instruction.CompareLessOrEqual, InstType.OpBinary, "<=", 4);
Add(Instruction.CompareLessOrEqualU32, InstType.OpBinary, "<="); Add(Instruction.CompareLessOrEqualU32, InstType.OpBinary, "<=", 4);
Add(Instruction.CompareLessU32, InstType.OpBinary, "<"); Add(Instruction.CompareLessU32, InstType.OpBinary, "<", 4);
Add(Instruction.CompareNotEqual, InstType.OpBinaryCom, "!="); Add(Instruction.CompareNotEqual, InstType.OpBinaryCom, "!=", 5);
Add(Instruction.ConditionalSelect, InstType.OpTernary, "?:"); Add(Instruction.ConditionalSelect, InstType.OpTernary, "?:", 12);
Add(Instruction.ConvertFP32ToFP64, 0); // MSL does not have a 64-bit FP Add(Instruction.ConvertFP32ToFP64, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertFP64ToFP32, 0); // MSL does not have a 64-bit FP Add(Instruction.ConvertFP64ToFP32, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertFP32ToS32, InstType.Cast, "int"); Add(Instruction.ConvertFP32ToS32, InstType.Cast, "int");
@ -61,7 +63,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.Ddx, InstType.CallUnary, "dfdx"); Add(Instruction.Ddx, InstType.CallUnary, "dfdx");
Add(Instruction.Ddy, InstType.CallUnary, "dfdy"); Add(Instruction.Ddy, InstType.CallUnary, "dfdy");
Add(Instruction.Discard, InstType.CallNullary, "discard_fragment"); Add(Instruction.Discard, InstType.CallNullary, "discard_fragment");
Add(Instruction.Divide, InstType.OpBinary, "/"); Add(Instruction.Divide, InstType.OpBinary, "/", 1);
Add(Instruction.EmitVertex, 0); // MSL does not have geometry shaders Add(Instruction.EmitVertex, 0); // MSL does not have geometry shaders
Add(Instruction.EndPrimitive, 0); // MSL does not have geometry shaders Add(Instruction.EndPrimitive, 0); // MSL does not have geometry shaders
Add(Instruction.ExponentB2, InstType.CallUnary, "exp2"); Add(Instruction.ExponentB2, InstType.CallUnary, "exp2");
@ -81,10 +83,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.Load, InstType.Special); Add(Instruction.Load, InstType.Special);
Add(Instruction.Lod, InstType.Special); Add(Instruction.Lod, InstType.Special);
Add(Instruction.LogarithmB2, InstType.CallUnary, "log2"); Add(Instruction.LogarithmB2, InstType.CallUnary, "log2");
Add(Instruction.LogicalAnd, InstType.OpBinaryCom, "&&"); Add(Instruction.LogicalAnd, InstType.OpBinaryCom, "&&", 9);
Add(Instruction.LogicalExclusiveOr, InstType.OpBinaryCom, "^"); Add(Instruction.LogicalExclusiveOr, InstType.OpBinaryCom, "^", 10);
Add(Instruction.LogicalNot, InstType.OpUnary, "!"); Add(Instruction.LogicalNot, InstType.OpUnary, "!", 0);
Add(Instruction.LogicalOr, InstType.OpBinaryCom, "||"); Add(Instruction.LogicalOr, InstType.OpBinaryCom, "||", 11);
Add(Instruction.LoopBreak, InstType.OpNullary, "break"); Add(Instruction.LoopBreak, InstType.OpNullary, "break");
Add(Instruction.LoopContinue, InstType.OpNullary, "continue"); Add(Instruction.LoopContinue, InstType.OpNullary, "continue");
Add(Instruction.PackDouble2x32, 0); // MSL does not have a 64-bit FP Add(Instruction.PackDouble2x32, 0); // MSL does not have a 64-bit FP
@ -95,27 +97,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.Minimum, InstType.CallBinary, "min"); Add(Instruction.Minimum, InstType.CallBinary, "min");
Add(Instruction.MinimumU32, InstType.CallBinary, "min"); Add(Instruction.MinimumU32, InstType.CallBinary, "min");
Add(Instruction.Modulo, InstType.CallBinary, "%"); Add(Instruction.Modulo, InstType.CallBinary, "%");
Add(Instruction.Multiply, InstType.OpBinaryCom, "*"); Add(Instruction.Multiply, InstType.OpBinaryCom, "*", 1);
Add(Instruction.MultiplyHighS32, InstType.Special); Add(Instruction.MultiplyHighS32, InstType.CallBinary, "mulhi");
Add(Instruction.MultiplyHighU32, InstType.Special); Add(Instruction.MultiplyHighU32, InstType.CallBinary, "mulhi");
Add(Instruction.Negate, InstType.Special); Add(Instruction.Negate, InstType.OpUnary, "-");
Add(Instruction.ReciprocalSquareRoot, InstType.CallUnary, "rsqrt"); Add(Instruction.ReciprocalSquareRoot, InstType.CallUnary, "rsqrt");
Add(Instruction.Return, InstType.OpNullary, "return"); Add(Instruction.Return, InstType.OpNullary, "return");
Add(Instruction.Round, InstType.CallUnary, "round"); Add(Instruction.Round, InstType.CallUnary, "round");
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<"); Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>"); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>"); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
// TODO: Shuffle funcs Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle");
Add(Instruction.Shuffle, 0); Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down");
Add(Instruction.ShuffleDown, 0); Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up");
Add(Instruction.ShuffleUp, 0); Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor");
Add(Instruction.ShuffleXor, 0);
Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special); Add(Instruction.Store, InstType.Special);
Add(Instruction.Subtract, InstType.OpBinary, "-"); Add(Instruction.Subtract, InstType.OpBinary, "-", 2);
// TODO: Swizzle add Add(Instruction.SwizzleAdd, InstType.CallTernary, HelperFunctionNames.SwizzleAdd);
Add(Instruction.SwizzleAdd, InstType.Special);
Add(Instruction.TextureSample, InstType.Special); Add(Instruction.TextureSample, InstType.Special);
Add(Instruction.TextureSize, InstType.Special); Add(Instruction.TextureSize, InstType.Special);
Add(Instruction.Truncate, InstType.CallUnary, "trunc"); Add(Instruction.Truncate, InstType.CallUnary, "trunc");
@ -123,15 +123,100 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.UnpackHalf2x16, InstType.CallUnary, "unpack_unorm2x16_to_half"); Add(Instruction.UnpackHalf2x16, InstType.CallUnary, "unpack_unorm2x16_to_half");
Add(Instruction.VectorExtract, InstType.Special); Add(Instruction.VectorExtract, InstType.Special);
Add(Instruction.VoteAll, InstType.CallUnary, "simd_all"); Add(Instruction.VoteAll, InstType.CallUnary, "simd_all");
// TODO: https://github.com/KhronosGroup/SPIRV-Cross/blob/bccaa94db814af33d8ef05c153e7c34d8bd4d685/reference/shaders-msl/comp/shader_group_vote.msl21.comp#L9
Add(Instruction.VoteAllEqual, InstType.Special); Add(Instruction.VoteAllEqual, InstType.Special);
Add(Instruction.VoteAny, InstType.CallUnary, "simd_any"); Add(Instruction.VoteAny, InstType.CallUnary, "simd_any");
#pragma warning restore IDE0055 #pragma warning restore IDE0055
} }
private static void Add(Instruction inst, InstType flags, string opName = null) private static void Add(Instruction inst, InstType flags, string opName = null, int precedence = 0)
{ {
_infoTable[(int)inst] = new InstInfo(flags, opName); _infoTable[(int)inst] = new InstInfo(flags, opName, precedence);
}
public static InstInfo GetInstructionInfo(Instruction inst)
{
return _infoTable[(int)(inst & Instruction.Mask)];
}
public static string GetSourceExpr(CodeGenContext context, IAstNode node, AggregateType dstType)
{
// TODO: Implement this
// return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType);
return "";
}
public static string Enclose(string expr, IAstNode node, Instruction pInst, bool isLhs)
{
InstInfo pInfo = GetInstructionInfo(pInst);
return Enclose(expr, node, pInst, pInfo, isLhs);
}
public static string Enclose(string expr, IAstNode node, Instruction pInst, InstInfo pInfo, bool isLhs = false)
{
if (NeedsParenthesis(node, pInst, pInfo, isLhs))
{
expr = "(" + expr + ")";
}
return expr;
}
public static bool NeedsParenthesis(IAstNode node, Instruction pInst, InstInfo pInfo, bool isLhs)
{
// If the node isn't a operation, then it can only be a operand,
// and those never needs to be surrounded in parenthesis.
if (node is not AstOperation operation)
{
// This is sort of a special case, if this is a negative constant,
// and it is consumed by a unary operation, we need to put on the parenthesis,
// as in MSL, while a sequence like ~-1 is valid, --2 is not.
if (IsNegativeConst(node) && pInfo.Type == InstType.OpUnary)
{
return true;
}
return false;
}
if ((pInfo.Type & (InstType.Call | InstType.Special)) != 0)
{
return false;
}
InstInfo info = _infoTable[(int)(operation.Inst & Instruction.Mask)];
if ((info.Type & (InstType.Call | InstType.Special)) != 0)
{
return false;
}
if (info.Precedence < pInfo.Precedence)
{
return false;
}
if (info.Precedence == pInfo.Precedence && isLhs)
{
return false;
}
if (pInst == operation.Inst && info.Type == InstType.OpBinaryCom)
{
return false;
}
return true;
}
private static bool IsNegativeConst(IAstNode node)
{
if (node is not AstOperand operand)
{
return false;
}
return operand.Type == OperandType.Constant && operand.Value < 0;
} }
} }
} }

View file

@ -6,10 +6,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
public string OpName { get; } public string OpName { get; }
public InstInfo(InstType type, string opName) public int Precedence { get; }
public InstInfo(InstType type, string opName, int precedence)
{ {
Type = type; Type = type;
OpName = opName; OpName = opName;
Precedence = precedence;
} }
} }
} }

View file

@ -1,6 +1,10 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation; using Ryujinx.Graphics.Shader.Translation;
using System;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.TypeConversion;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
@ -18,7 +22,178 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
Declarations.Declare(context, info); Declarations.Declare(context, info);
if (info.Functions.Count != 0)
{
for (int i = 1; i < info.Functions.Count; i++)
{
context.AppendLine($"{GetFunctionSignature(context, info.Functions[i], config.Stage)};");
}
context.AppendLine();
for (int i = 1; i < info.Functions.Count; i++)
{
PrintFunction(context, info.Functions[i], config.Stage);
context.AppendLine();
}
}
PrintFunction(context, info.Functions[0], config.Stage, true);
return context.GetCode(); return context.GetCode();
} }
private static void PrintFunction(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
context.CurrentFunction = function;
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
context.EnterScope();
Declarations.DeclareLocals(context, function);
PrintBlock(context, function.MainBlock, isMainFunc);
context.LeaveScope();
}
private static string GetFunctionSignature(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
string[] args = new string[function.InArguments.Length + function.OutArguments.Length];
for (int i = 0; i < function.InArguments.Length; i++)
{
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
}
for (int i = 0; i < function.OutArguments.Length; i++)
{
int j = i + function.InArguments.Length;
// Likely need to be made into pointers
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
}
string funcKeyword = "inline";
string funcName = null;
if (isMainFunc)
{
if (stage == ShaderStage.Vertex)
{
funcKeyword = "vertex";
funcName = "vertexMain";
}
else if (stage == ShaderStage.Fragment)
{
funcKeyword = "fragment";
funcName = "fragmentMain";
}
}
return $"{funcKeyword} {Declarations.GetVarTypeName(context, function.ReturnType)} {funcName ?? function.Name}({string.Join(", ", args)})";
}
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)
{
AstBlockVisitor visitor = new(block);
visitor.BlockEntered += (sender, e) =>
{
switch (e.Block.Type)
{
case AstBlockType.DoWhile:
context.AppendLine("do");
break;
case AstBlockType.Else:
context.AppendLine("else");
break;
case AstBlockType.ElseIf:
context.AppendLine($"else if ({GetCondExpr(context, e.Block.Condition)})");
break;
case AstBlockType.If:
context.AppendLine($"if ({GetCondExpr(context, e.Block.Condition)})");
break;
default:
throw new InvalidOperationException($"Found unexpected block type \"{e.Block.Type}\".");
}
context.EnterScope();
};
visitor.BlockLeft += (sender, e) =>
{
context.LeaveScope();
if (e.Block.Type == AstBlockType.DoWhile)
{
context.AppendLine($"while ({GetCondExpr(context, e.Block.Condition)});");
}
};
bool supportsBarrierDivergence = context.Config.GpuAccessor.QueryHostSupportsShaderBarrierDivergence();
bool mayHaveReturned = false;
foreach (IAstNode node in visitor.Visit())
{
if (node is AstOperation operation)
{
if (!supportsBarrierDivergence)
{
if (operation.Inst == IntermediateRepresentation.Instruction.Barrier)
{
// Barrier on divergent control flow paths may cause the GPU to hang,
// so skip emitting the barrier for those cases.
if (visitor.Block.Type != AstBlockType.Main || mayHaveReturned || !isMainFunction)
{
context.Config.GpuAccessor.Log($"Shader has barrier on potentially divergent block, the barrier will be removed.");
continue;
}
}
else if (operation.Inst == IntermediateRepresentation.Instruction.Return)
{
mayHaveReturned = true;
}
}
string expr = InstGen.GetExpression(context, operation);
if (expr != null)
{
context.AppendLine(expr + ";");
}
}
else if (node is AstAssignment assignment)
{
AggregateType dstType = OperandManager.GetNodeDestType(context, assignment.Destination);
AggregateType srcType = OperandManager.GetNodeDestType(context, assignment.Source);
string dest = InstGen.GetExpression(context, assignment.Destination);
string src = ReinterpretCast(context, assignment.Source, srcType, dstType);
context.AppendLine(dest + " = " + src + ";");
}
else if (node is AstComment comment)
{
context.AppendLine("// " + comment.Comment);
}
else
{
throw new InvalidOperationException($"Found unexpected node type \"{node?.GetType().Name ?? "null"}\".");
}
}
}
private static string GetCondExpr(CodeGenContext context, IAstNode cond)
{
AggregateType srcType = OperandManager.GetNodeDestType(context, cond);
return ReinterpretCast(context, cond, srcType, AggregateType.Bool);
}
} }
} }

View file

@ -0,0 +1,102 @@
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Globalization;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class NumberFormatter
{
private const int MaxDecimal = 256;
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{
if (dstType == AggregateType.FP32)
{
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
}
else if (dstType == AggregateType.S32)
{
formatted = FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
formatted = FormatUint((uint)value);
}
else if (dstType == AggregateType.Bool)
{
formatted = value != 0 ? "true" : "false";
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
return true;
}
public static string FormatFloat(float value)
{
if (!TryFormatFloat(value, out string formatted))
{
throw new ArgumentException("Failed to convert float value to string.");
}
return formatted;
}
public static bool TryFormatFloat(float value, out string formatted)
{
if (float.IsNaN(value) || float.IsInfinity(value))
{
formatted = null;
return false;
}
formatted = value.ToString("F9", CultureInfo.InvariantCulture);
if (!formatted.Contains('.'))
{
formatted += ".0f";
}
return true;
}
public static string FormatInt(int value, AggregateType dstType)
{
if (dstType == AggregateType.S32)
{
return FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
}
public static string FormatInt(int value)
{
if (value <= MaxDecimal && value >= -MaxDecimal)
{
return value.ToString(CultureInfo.InvariantCulture);
}
return "0x" + value.ToString("X", CultureInfo.InvariantCulture);
}
public static string FormatUint(uint value)
{
if (value <= MaxDecimal && value >= 0)
{
return value.ToString(CultureInfo.InvariantCulture) + "u";
}
return "0x" + value.ToString("X", CultureInfo.InvariantCulture) + "u";
}
}
}

View file

@ -0,0 +1,173 @@
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
class OperandManager
{
private readonly Dictionary<AstOperand, string> _locals;
public OperandManager()
{
_locals = new Dictionary<AstOperand, string>();
}
public string DeclareLocal(AstOperand operand)
{
string name = $"{DefaultNames.LocalNamePrefix}_{_locals.Count}";
_locals.Add(operand, name);
return name;
}
public string GetExpression(CodeGenContext context, AstOperand operand)
{
return operand.Type switch
{
OperandType.Argument => GetArgumentName(operand.Value),
OperandType.Constant => NumberFormatter.FormatInt(operand.Value),
OperandType.LocalVariable => _locals[operand],
OperandType.Undefined => DefaultNames.UndefinedName,
_ => throw new ArgumentException($"Invalid operand type \"{operand.Type}\"."),
};
}
public static string GetArgumentName(int argIndex)
{
return $"{DefaultNames.ArgumentNamePrefix}{argIndex}";
}
public static AggregateType GetNodeDestType(CodeGenContext context, IAstNode node)
{
// TODO: Get rid of that function entirely and return the type from the operation generation
// functions directly, like SPIR-V does.
if (node is AstOperation operation)
{
if (operation.Inst == Instruction.Load || operation.Inst.IsAtomic())
{
switch (operation.StorageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (operation.GetSource(0) is not AstOperand bindingIndex || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
if (operation.GetSource(1) is not AstOperand fieldIndex || fieldIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
BufferDefinition buffer = operation.StorageKind == StorageKind.ConstantBuffer
? context.Config.Properties.ConstantBuffers[bindingIndex.Value]
: context.Config.Properties.StorageBuffers[bindingIndex.Value];
StructureField field = buffer.Type.Fields[fieldIndex.Value];
return field.Type & AggregateType.ElementTypeMask;
case StorageKind.LocalMemory:
case StorageKind.SharedMemory:
if (operation.GetSource(0) is not AstOperand { Type: OperandType.Constant } bindingId)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
MemoryDefinition memory = operation.StorageKind == StorageKind.LocalMemory
? context.Config.Properties.LocalMemories[bindingId.Value]
: context.Config.Properties.SharedMemories[bindingId.Value];
return memory.Type & AggregateType.ElementTypeMask;
case StorageKind.Input:
case StorageKind.InputPerPatch:
case StorageKind.Output:
case StorageKind.OutputPerPatch:
if (operation.GetSource(0) is not AstOperand varId || varId.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
IoVariable ioVariable = (IoVariable)varId.Value;
bool isOutput = operation.StorageKind == StorageKind.Output || operation.StorageKind == StorageKind.OutputPerPatch;
bool isPerPatch = operation.StorageKind == StorageKind.InputPerPatch || operation.StorageKind == StorageKind.OutputPerPatch;
int location = 0;
int component = 0;
if (context.Config.HasPerLocationInputOrOutput(ioVariable, isOutput))
{
if (operation.GetSource(1) is not AstOperand vecIndex || vecIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
location = vecIndex.Value;
if (operation.SourcesCount > 2 &&
operation.GetSource(2) is AstOperand elemIndex &&
elemIndex.Type == OperandType.Constant &&
context.Config.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput))
{
component = elemIndex.Value;
}
}
(_, AggregateType varType) = IoMap.GetMSLBuiltIn(ioVariable);
return varType & AggregateType.ElementTypeMask;
}
}
else if (operation.Inst == Instruction.Call)
{
AstOperand funcId = (AstOperand)operation.GetSource(0);
Debug.Assert(funcId.Type == OperandType.Constant);
return context.GetFunction(funcId.Value).ReturnType;
}
else if (operation.Inst == Instruction.VectorExtract)
{
return GetNodeDestType(context, operation.GetSource(0)) & ~AggregateType.ElementCountMask;
}
else if (operation is AstTextureOperation texOp)
{
if (texOp.Inst == Instruction.ImageLoad ||
texOp.Inst == Instruction.ImageStore ||
texOp.Inst == Instruction.ImageAtomic)
{
return texOp.GetVectorType(texOp.Format.GetComponentType());
}
else if (texOp.Inst == Instruction.TextureSample)
{
return texOp.GetVectorType(GetDestVarType(operation.Inst));
}
}
return GetDestVarType(operation.Inst);
}
else if (node is AstOperand operand)
{
if (operand.Type == OperandType.Argument)
{
int argIndex = operand.Value;
return context.CurrentFunction.GetArgumentType(argIndex);
}
return OperandInfo.GetVarType(operand);
}
else
{
throw new ArgumentException($"Invalid node type \"{node?.GetType().Name ?? "null"}\".");
}
}
}
}

View file

@ -0,0 +1,95 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class TypeConversion
{
public static string ReinterpretCast(
CodeGenContext context,
IAstNode node,
AggregateType srcType,
AggregateType dstType)
{
if (node is AstOperand operand && operand.Type == OperandType.Constant)
{
if (NumberFormatter.TryFormat(operand.Value, dstType, out string formatted))
{
return formatted;
}
}
string expr = InstGen.GetExpression(context, node);
return ReinterpretCast(expr, node, srcType, dstType);
}
private static string ReinterpretCast(string expr, IAstNode node, AggregateType srcType, AggregateType dstType)
{
if (srcType == dstType)
{
return expr;
}
if (srcType == AggregateType.FP32)
{
switch (dstType)
{
case AggregateType.Bool:
return $"(floatBitsToInt({expr}) != 0)";
case AggregateType.S32:
return $"floatBitsToInt({expr})";
case AggregateType.U32:
return $"floatBitsToUint({expr})";
}
}
else if (dstType == AggregateType.FP32)
{
switch (srcType)
{
case AggregateType.Bool:
return $"intBitsToFloat({ReinterpretBoolToInt(expr, node, AggregateType.S32)})";
case AggregateType.S32:
return $"intBitsToFloat({expr})";
case AggregateType.U32:
return $"uintBitsToFloat({expr})";
}
}
else if (srcType == AggregateType.Bool)
{
return ReinterpretBoolToInt(expr, node, dstType);
}
else if (dstType == AggregateType.Bool)
{
expr = InstGenHelper.Enclose(expr, node, Instruction.CompareNotEqual, isLhs: true);
return $"({expr} != 0)";
}
else if (dstType == AggregateType.S32)
{
return $"int({expr})";
}
else if (dstType == AggregateType.U32)
{
return $"uint({expr})";
}
Logger.Warning?.Print(LogClass.Gpu, $"Invalid reinterpret cast from \"{srcType}\" to \"{dstType}\".");
// TODO: Make this an error again
return $"INVALID CAST ({expr})";
}
private static string ReinterpretBoolToInt(string expr, IAstNode node, AggregateType dstType)
{
string trueExpr = NumberFormatter.FormatInt(IrConsts.True, dstType);
string falseExpr = NumberFormatter.FormatInt(IrConsts.False, dstType);
expr = InstGenHelper.Enclose(expr, node, Instruction.ConditionalSelect, isLhs: false);
return $"({expr} ? {trueExpr} : {falseExpr})";
}
}
}

View file

@ -14,5 +14,8 @@
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighU32.glsl" /> <EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighU32.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\SwizzleAdd.glsl" /> <EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\SwizzleAdd.glsl" />
</ItemGroup> </ItemGroup>
<ItemGroup>
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\VoteAllEqual.metal" />
</ItemGroup>
</Project> </Project>