RegisterFile.md

File metadata and controls

126 lines (96 loc) · 14.6 KB

Synopsis

#include <stdint.h>

union amx_reg { // 64 byte vector...
    // ...of unsigned integers:
    uint8_t  u8 [64];
    uint16_t u16[32];
    uint32_t u32[16];
    // ...of signed integers:
    int8_t   i8 [64];
    int16_t  i16[32];
    int32_t  i32[16];
    // ...of IEEE 754 floating point:
    _Float16 f16[32]; // NB: IEEE half-precision, _not_ BF16
    float    f32[16];
    double   f64[ 8];
};

struct amx_state {
    union amx_reg x[ 8]; // 512 bytes, of which 64 bytes extracted / inserted by operations
    union amx_reg y[ 8]; // 512 bytes, of which 64 bytes extracted / inserted by operations
    union amx_reg z[64]; // 64 by 64 matrix of bytes
}; // 5 KB total (512 + 512 + 4096 bytes)

Description

Each register is 64 bytes, viewed as a vector of u8/u16/u32/i8/i16/i32/f16/f32/f64 elements. The architectural state contains 80 such registers: 8 in the X pool, 8 in the Y pool, and the remaining 64 forming a 64x64 grid of bytes called Z.

The entire X register pool can be concatenated to form a circular buffer of 512 bytes. Most instructions can operate on any contiguous 64 byte range from this circular buffer; a range that runs past byte 511 wraps around to byte 0. The same is true for Y, which forms its own 512 byte circular buffer.
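The wrap-around addressing can be modelled in a few lines of C. This is only a sketch of the rule described above, not of any particular instruction: the function name `amx_read_span` and the flat 512-byte `pool` array are assumptions made for illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: copy the 64 bytes that start at byte `offset` of a
   512-byte X or Y pool, wrapping from byte 511 back to byte 0. */
void amx_read_span(const uint8_t pool[512], unsigned offset, uint8_t out[64]) {
    assert(offset < 512);
    for (unsigned i = 0; i < 64; ++i) {
        out[i] = pool[(offset + i) % 512];
    }
}
```

With `offset` 500, for example, the first 12 output bytes come from pool bytes 500 to 511 and the remaining 52 come from pool bytes 0 to 51.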

Once 64 bytes of X and 64 bytes of Y have been selected, operations between X and Y and Z can be performed. Said operations fall into two main categories:

  • Vector: Select one register from Z, and combine X/Y/Z in a standard SIMD manner: Z[i] += X[i] * Y[i]
  • Matrix: Select a number of registers from Z equal to the number of lanes in X and Y, and combine X/Y/Z in an outer-product manner: Z[j][i] += X[i] * Y[j]
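The two categories can be written out for f32 lanes as a reference model. This is a sketch under assumptions: the function names are hypothetical, the Z registers taking part in the matrix case are passed as a dense 16 by 16 array rather than being picked out of the 64 real Z registers, and the arithmetic is a plain multiply followed by an add.

```c
#include <stddef.h>

enum { AMX_F32_LANES = 16 }; /* 64 bytes / 4 bytes per f32 */

/* Vector mode: one Z register, combined lane by lane. */
void vec_fma_f32(float z[AMX_F32_LANES],
                 const float x[AMX_F32_LANES],
                 const float y[AMX_F32_LANES]) {
    for (size_t i = 0; i < AMX_F32_LANES; ++i) {
        z[i] += x[i] * y[i];
    }
}

/* Matrix mode: as many Z registers as there are lanes, one per lane of Y,
   accumulating the outer product of X and Y. */
void mat_fma_f32(float z[AMX_F32_LANES][AMX_F32_LANES],
                 const float x[AMX_F32_LANES],
                 const float y[AMX_F32_LANES]) {
    for (size_t j = 0; j < AMX_F32_LANES; ++j) {
        for (size_t i = 0; i < AMX_F32_LANES; ++i) {
            z[j][i] += x[i] * y[j];
        }
    }
}
```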

Getting data into and out of the AMX registers

Load/store instructions move data between memory and AMX registers.

Computation instructions can be used to synthesise various constants in the AMX registers: 0 is easy, as is floating-point -0. The latter can be used with integer shift instructions to synthesise (positive or negative) integer powers of two.
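The arithmetic behind the power-of-two trick can be checked on the host. The sketch below assumes 32-bit lanes and one plausible reading of the paragraph: a logical right shift of the -0 bit pattern gives positive powers of two, and an arithmetic right shift gives negative ones. The function names are hypothetical and no AMX instruction is involved.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Bit pattern of the f32 value -0.0: only the sign bit is set. */
uint32_t f32_negative_zero_bits(void) {
    float negative_zero = -0.0f;
    uint32_t bits;
    memcpy(&bits, &negative_zero, sizeof bits);
    return bits;
}

/* Logical right shift by k moves the single set bit down: 2^(31-k) as u32. */
uint32_t pow2_via_logical_shift(unsigned k) {
    assert(k < 32);
    return f32_negative_zero_bits() >> k;
}

/* Arithmetic right shift by k replicates the sign bit: -2^(31-k) as i32.
   Written with unsigned operations so the C code is fully portable. */
int32_t neg_pow2_via_arithmetic_shift(unsigned k) {
    assert(k < 32);
    uint32_t bits = f32_negative_zero_bits();
    uint32_t shifted = ~(~bits >> k);
    int32_t result;
    memcpy(&result, &shifted, sizeof result);
    return result;
}
```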

There is no direct movement between AMX registers and A64 general purpose registers or SIMD registers; data has to go via memory.

Indexed loads

By default, instructions operate on a 64-byte span from X or Y. Some operations support indexed loads rather than 64-byte span loads. Said loads are parameterised by two things: the element size (ES), which is 8/16/32/64 bits, and the index size (IS), which is 2/4/5 bits. The element count (EC) is then 512 divided by the element size.

A regular load would load an ES * EC (i.e. 512) bit span from X or Y. An indexed load instead loads an IS * EC bit span from X or Y, and then treats every group of IS bits as a lane index into a different register with element size ES. For example, taking ES of 16 for f16 data and IS of 2, EC is 32, so a 64-bit span is loaded from X or Y, which can be viewed as a u2[32] vector, and is expanded to form an f16[32] vector by looking up into lanes 0/1/2/3 of some other f16[32] vector.
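The expansion step can be sketched in C for ES of 16. Two things here are assumptions rather than facts from this page: the packed indices are taken to be little-endian (lane 0 in the lowest IS bits of the first byte), and the lookup table is passed as raw 16-bit lanes so the example does not depend on `_Float16` support. The names `packed_index` and `indexed_load_16` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Read the IS-bit index for `lane` from a packed bit string.
   Assumed bit order: lane 0 occupies the lowest bits of packed[0]. */
static unsigned packed_index(const uint8_t *packed, unsigned lane, unsigned is_bits) {
    unsigned bit = lane * is_bits;
    unsigned shift = bit % 8;
    unsigned value = packed[bit / 8] >> shift;
    if (shift + is_bits > 8) { /* index straddles a byte boundary */
        value |= (unsigned)packed[bit / 8 + 1] << (8 - shift);
    }
    return value & ((1u << is_bits) - 1u);
}

/* Expand 32 packed indices (IS * 32 bits) into a 32-lane vector of 16-bit
   elements by looking each index up in `table`. */
void indexed_load_16(const uint8_t *packed, unsigned is_bits,
                     const uint16_t table[32], uint16_t out[32]) {
    assert(is_bits == 2 || is_bits == 4 || is_bits == 5);
    for (unsigned lane = 0; lane < 32; ++lane) {
        out[lane] = table[packed_index(packed, lane, is_bits)];
    }
}
```

With IS of 2 the packed input is 8 bytes and only table lanes 0 to 3 are reachable; with IS of 5 it is 20 bytes and all 32 table lanes are reachable.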

Shuffles

Once a 64 byte X (or Y) vector has been obtained (either by a regular load or an indexed load), some instructions support shuffling the 64 bytes before use.

For vectors of 8 elements (i.e. f64[8]), the four (albeit only three distinct) available shuffles are:

    0  1  2  3  4  5  6  7
S0  0  1  2  3  4  5  6  7
S1  0  4  1  5  2  6  3  7
S2  0  2  4  6  1  3  5  7
S3  0  1  2  3  4  5  6  7

For vectors of 16 elements (i.e. f32[16] or i32[16] or u32[16]), the four available shuffles are:

    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
S0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
S1  0  8  1  9  2 10  3 11  4 12  5 13  6 14  7 15
S2  0  4  8 12  1  5  9 13  2  6 10 14  3  7 11 15
S3  0  2  4  6  8 10 12 14  1  3  5  7  9 11 13 15

For vectors of 32 elements (i.e. f16[32] or i16[32] or u16[32]), the four available shuffles are:

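    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
S0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
S1  0 16  1 17  2 18  3 19  4 20  5 21  6 22  7 23  8 24  9 25 10 26 11 27 12 28 13 29 14 30 15 31
S2  0  8 16 24  1  9 17 25  2 10 18 26  3 11 19 27  4 12 20 28  5 13 21 29  6 14 22 30  7 15 23 31
S3  0  4  8 12 16 20 24 28  1  5  9 13 17 21 25 29  2  6 10 14 18 22 26 30  3  7 11 15 19 23 27 31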

For vectors of 64 elements (i.e. i8[64] or u8[64]), the four available shuffles are:

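    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
S0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
S1  0 32  1 33  2 34  3 35  4 36  5 37  6 38  7 39  8 40  9 41 10 42 11 43 12 44 13 45 14 46 15 47 16 48 17 49 18 50 19 51 20 52 21 53 22 54 23 55 24 56 25 57 26 58 27 59 28 60 29 61 30 62 31 63
S2  0 16 32 48  1 17 33 49  2 18 34 50  3 19 35 51  4 20 36 52  5 21 37 53  6 22 38 54  7 23 39 55  8 24 40 56  9 25 41 57 10 26 42 58 11 27 43 59 12 28 44 60 13 29 45 61 14 30 46 62 15 31 47 63
S3  0  8 16 24 32 40 48 56  1  9 17 25 33 41 49 57  2 10 18 26 34 42 50 58  3 11 19 27 35 43 51 59  4 12 20 28 36 44 52 60  5 13 21 29 37 45 53 61  6 14 22 30 38 46 54 62  7 15 23 31 39 47 55 63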

In all cases, S0 is the identity, S1 moves lane 1 to lane 2, S2 moves lane 1 to lane 4, and S3 moves lane 1 to lane 8.
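All of the tables above fit a single closed form, which the following C sketch captures. The formula is inferred from the tables rather than taken from any specification, and the function name `amx_shuffle_source` is hypothetical. It returns, for output lane `i`, the input lane that the shuffle places there.

```c
#include <assert.h>

/* For shuffle S<k> on a vector of n lanes, return the input lane that ends
   up in output lane i. */
unsigned amx_shuffle_source(unsigned n, unsigned k, unsigned i) {
    assert(n == 8 || n == 16 || n == 32 || n == 64);
    assert(k < 4 && i < n);
    unsigned stride = 1u << k;   /* S0 -> 1, S1 -> 2, S2 -> 4, S3 -> 8 */
    unsigned group = n / stride; /* distance between neighbouring sources */
    return i / stride + (i % stride) * group;
}
```

For n of 8 and k of 3 the stride equals the lane count, so the formula reduces to the identity. This matches the duplicate S0 and S3 rows in the 8-element table.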

Per-byte write-enable

Most instructions support writing to only a subset of the output lanes, leaving the other lanes unchanged. This is controlled by a combination of a mode field and a value field. Said fields typically combine along the lines of:

Mode  Meaning of value (N)
0     Write to all lanes (0), or to odd lanes only (1), or to even lanes only (2), or to no lanes
1     Only write lane #N (or for certain vector operations, write all lanes, but broadcast Y lane #N to all lanes of Y)
2     Only write first N lanes, or to all lanes when N is zero
3     Only write last N lanes, or to all lanes when N is zero
4     Only write first N lanes (no lanes when N is zero)
5     Only write last N lanes (no lanes when N is zero)
6     Write to no lanes
7     Write to no lanes

Matrix operations have separate write-enable for the X axis and the Y axis, with the enabled Z elements being the outer product of the two write-enables.
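The table can be turned into a lane mask as follows. This is a sketch of the typical combination only: individual instructions may differ, the mode 1 broadcast variant is ignored, and N values larger than the lane count are clamped, which is an assumption. The names `amx_write_mask` and `amx_z_enabled` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Mask with the lowest `count` bits set (count may be 0..64). */
static uint64_t low_bits(unsigned count) {
    return count >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << count) - 1;
}

/* Bit i of the result is set when lane i is written. */
uint64_t amx_write_mask(unsigned mode, unsigned n, unsigned lanes) {
    assert(lanes >= 1 && lanes <= 64);
    uint64_t all = low_bits(lanes);
    unsigned count = n > lanes ? lanes : n;          /* assumption: clamp N */
    uint64_t first = low_bits(count);                /* lanes 0 .. count-1 */
    uint64_t last = all & ~low_bits(lanes - count);  /* the top `count` lanes */
    switch (mode) {
    case 0:
        if (n == 0) return all;
        if (n == 1) return all & UINT64_C(0xAAAAAAAAAAAAAAAA); /* odd lanes */
        if (n == 2) return all & UINT64_C(0x5555555555555555); /* even lanes */
        return 0;
    case 1: return n < lanes ? UINT64_C(1) << n : 0;
    case 2: return n == 0 ? all : first;
    case 3: return n == 0 ? all : last;
    case 4: return first;
    case 5: return last;
    default: return 0; /* modes 6 and 7 */
    }
}

/* Matrix operations: Z element (row j, lane i) is written only when the
   X-axis mask enables lane i and the Y-axis mask enables lane j. */
int amx_z_enabled(uint64_t x_mask, uint64_t y_mask, unsigned i, unsigned j) {
    return (int)((x_mask >> i) & (y_mask >> j) & 1u);
}
```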

Mixed lane widths

When the element size is identical between X and Y and Z, indexing is simple. Assume an element size in bits (ES) of 8, 16, 32, or 64 for all three, then X and Y have N elements, where N = 512 / ES. In vector mode, a single Z register also has N elements. In matrix mode, a 2D grid of N² values is used from Z: N distinct registers from Z, each containing N elements. The N distinct registers are equally spaced in the Y dimension, with spacing 64 / N (the user can choose the starting row, subject to 0 ≤ starting row < 64 / N).
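The row selection for the same-width matrix case reduces to a one-line formula, sketched here in C. The function name is hypothetical, and the sketch assumes that the j-th selected register (counting upwards from the starting row) is the one paired with lane j of Y.

```c
#include <assert.h>

/* Index (0..63) of the Z register that holds row j of the N by N result,
   for element size es_bits and a chosen starting row. */
unsigned amx_matrix_z_row(unsigned es_bits, unsigned start_row, unsigned j) {
    assert(es_bits == 8 || es_bits == 16 || es_bits == 32 || es_bits == 64);
    unsigned n = 512 / es_bits; /* lanes in X, in Y, and in one Z register */
    unsigned spacing = 64 / n;  /* equals the element size in bytes */
    assert(start_row < spacing && j < n);
    return start_row + j * spacing;
}
```

For f32 (ES of 32), N is 16 and the spacing is 4, so starting row 1 selects Z registers 1, 5, 9, and so on up to 61. For 8-bit elements the spacing is 1 and all 64 registers are used.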

When the element sizes are mixed (for example f16 × f16 ↦ f32 or i8 × i16 ↦ i32), then things are more complex. Either more Z registers need to be used (to make space for all the outputs), or some lanes from X and/or Y need to be dropped (because otherwise there is not space for all the outputs), or a combination of both. When lanes are dropped, it is typical to keep just the even lanes, or keep just one lane from every four (i.e. keep lanes 0, 4, 8, etc). Shuffles can be used to select different lanes; for example after applying shuffle S1 and then keeping just the even lanes, the result is lanes 0, 1, 2, etc; and after applying shuffle S2 and then keeping just one lane from every four then the result is lanes 0, 1, 2, etc. Alternatively, byte offsets on the input operands can be used to select different lanes: adding a byte offset equal to one lane turns even lanes into odd lanes, and turns lanes 0, 4, 8, etc into 1, 5, 9, etc.

One particularly common mixed-width combination is X and Y having element size of 16 bits (i.e. i16 or u16 or f16) and Z having element size 32 bits (i.e. i32 or u32 or f32). In this case, both X and Y have 32 elements, and every Z register has 16 elements. The complete outer product of X and Y would need 32² (i.e. 1024) Z values, which there is just space for: use all 64 Z registers, with 16 elements in each. Each 4 by 4 block of bytes ends up looking like:

        X[0:1]  X[2:3]
Y[0:1]  Z[0][0:3] += X[0:1] × Y[0:1]
        Z[1][0:3] += X[2:3] × Y[0:1]
Y[2:3]  Z[2][0:3] += X[0:1] × Y[2:3]
        Z[3][0:3] += X[2:3] × Y[2:3]

An alternative way of viewing this combination is that every pair of Z registers contains 32 lanes (corresponding to the lanes of X), and there are 32 such pairs (corresponding to the lanes of Y), with each pair arranged as:

Z0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30
Z1  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31

This arrangement is called an interleaved pair of Z registers, and for (16,16,32) has support instructions in the form of ldzi and stzi.
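The interleaved layout can be expressed as an index mapping: lane i of X and lane j of Y land in Z register 2*j + (i mod 2), at 32-bit lane i / 2. The C sketch below applies that mapping to an i16 × i16 ↦ i32 outer product. The function name and the use of a plain `int32_t z[64][16]` array for Z are assumptions made for illustration.

```c
#include <stdint.h>

/* Accumulate the full 32 by 32 outer product of x and y into all 64 Z
   registers, using the interleaved-pair layout for (16,16,32). */
void mac16_interleaved(int32_t z[64][16], const int16_t x[32], const int16_t y[32]) {
    for (unsigned j = 0; j < 32; ++j) {       /* lane of Y selects the pair */
        for (unsigned i = 0; i < 32; ++i) {   /* lane of X */
            unsigned reg = 2 * j + (i & 1);   /* even X lanes: first register */
            unsigned lane = i >> 1;           /* 32-bit lane within it */
            z[reg][lane] += (int32_t)x[i] * (int32_t)y[j];
        }
    }
}
```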