HWBTutorials/aes-performance/aes-intrinsic/aes.cpp

145 lines
4.6 KiB
C++

#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdint.h>
#include <stdlib.h>
#include <emmintrin.h>
#include <immintrin.h>
#include <wmmintrin.h>
/* AES-128 simple implementation template and testing */
/*
Author: Manuel Thalmann, thalmman@fit.cvut.cz
Template: Jiri Bucek 2017
AES specification:
http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf
*/
/* AES Constants */
// AES field reduction polynomial x^8 + x^4 + x^3 + x + 1, i.e. 0x11B
// (binary 1 0001 1011). The AES-NI code below does not use it.
const uint16_t POLYNOMIAL = 0x11B;
/*
 * Derives the next AES-128 round key from the previous one.
 *
 * key:             previous round key, words w0..w3 (w0 in the lowest lane).
 * expansionSource: output of _mm_aeskeygenassist_si128(key, rcon). Its highest
 *                  32-bit lane holds RotWord(SubWord(w3)) ^ rcon.
 *
 * Returns the next round key. Word i of the result is
 * temp ^ w0 ^ ... ^ wi, where temp is lane 3 of expansionSource.
 */
__m128i computeKey(__m128i key, __m128i expansionSource) {
    // Broadcast lane 3 (the RotWord/SubWord/Rcon word) into all four lanes.
    const __m128i temp = _mm_shuffle_epi32(expansionSource, 0xFF);

    // Prefix XOR across lanes: lane i becomes w0 ^ ... ^ wi.
    // Three 4-byte shifts are enough; a fourth would shift the whole
    // register out and XOR with zero.
    __m128i prefix = key;
    key = _mm_slli_si128(key, 4);
    prefix = _mm_xor_si128(prefix, key);
    key = _mm_slli_si128(key, 4);
    prefix = _mm_xor_si128(prefix, key);
    key = _mm_slli_si128(key, 4);
    prefix = _mm_xor_si128(prefix, key);

    // Use the intrinsic rather than operator^, which is only a GCC/Clang
    // vector extension on __m128i.
    return _mm_xor_si128(temp, prefix);
}
// Writes round key `index` of the schedule. It is derived from round key
// `index - 1` and the matching _mm_aeskeygenassist_si128 output `expSource`.
// `index` must be at least 1.
void addKey(uint8_t index, __m128i expKey[11], __m128i expSource) {
    const __m128i previous = expKey[index - 1];
    expKey[index] = computeKey(previous, expSource);
}
/*
* Key expansion from 128bits (4*32b)
* to 11 round keys (11*4*32b)
* each round key is 4*32b
*/
// Taken from: https://www.brainkart.com/article/AES-Key-Expansion_8410/
void expandKey(__m128i key, __m128i expKey[11]) {
__m128i expSource;
_mm_storeu_si128(&expKey[0], key);
expSource = _mm_aeskeygenassist_si128(expKey[0], 0x01);
addKey(1, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[1], 0x02);
addKey(2, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[2], 0x04);
addKey(3, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[3], 0x08);
addKey(4, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[4], 0x10);
addKey(5, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[5], 0x20);
addKey(6, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[6], 0x40);
addKey(7, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[7], 0x80);
addKey(8, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[8], 0x1B);
addKey(9, expKey, expSource);
expSource = _mm_aeskeygenassist_si128(expKey[9], 0x36);
addKey(10, expKey, expSource);
}
void aes(__m128i *value, __m128i key)
{
//... Initialize ...
__m128i expKey[11];
__m128i tmp = _mm_load_si128(value);
expandKey(key, expKey);
tmp = _mm_xor_si128(tmp, expKey[0]);
tmp = _mm_aesenc_si128(tmp, expKey[1]);
tmp = _mm_aesenc_si128(tmp, expKey[2]);
tmp = _mm_aesenc_si128(tmp, expKey[3]);
tmp = _mm_aesenc_si128(tmp, expKey[4]);
tmp = _mm_aesenc_si128(tmp, expKey[5]);
tmp = _mm_aesenc_si128(tmp, expKey[6]);
tmp = _mm_aesenc_si128(tmp, expKey[7]);
tmp = _mm_aesenc_si128(tmp, expKey[8]);
tmp = _mm_aesenc_si128(tmp, expKey[9]);
tmp = _mm_aesenclast_si128(tmp, expKey[10]);
_mm_storeu_si128(value, tmp);
}
//****************************
// MAIN function: AES testing
//****************************
/*
 * Parses the optional cycle-count argument into `out`.
 * Accepts only a plain unsigned decimal number: no sign, no leading
 * whitespace, no trailing characters, and the value must fit into uint32_t.
 * Returns false (leaving `out` untouched) for anything else.
 */
static bool parseCycles(const char *text, uint32_t &out)
{
    if (text == nullptr || *text < '0' || *text > '9') {
        return false;
    }
    errno = 0;
    char *end = nullptr;
    const unsigned long long parsed = std::strtoull(text, &end, 10);
    if (errno == ERANGE || *end != '\0' || parsed > UINT32_MAX) {
        return false;
    }
    out = static_cast<uint32_t>(parsed);
    return true;
}

/*
 * Benchmarks repeated in-place AES-128 encryption of one block.
 * Usage: aes [cycles]   (default 1000000)
 * When run with the default count, the final block is checked against a
 * precomputed expected ciphertext.
 */
int main(int argc, char* argv[])
{
    // The expected result below is only known for this exact number of runs.
    const uint32_t validatedCycles = 1000000;
    uint32_t cycles = validatedCycles;
    __m128i key = _mm_setr_epi8(0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff);
    __m128i value = _mm_setr_epi8(0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89);
    const uint8_t expected[16] = { 0x1d, 0x07, 0x34, 0x40, 0xeb, 0xbe, 0x24, 0xc5, 0x02, 0x8b, 0xd8, 0x02, 0x65, 0xc8, 0xfb, 0x1d };
    if (argc > 2) {
        std::cerr << "Invalid number of arguments\n";
        exit(EXIT_FAILURE);
    } else if (argc == 2 && !parseCycles(argv[1], cycles)) {
        // atoi used to turn garbage into 0 and negative values into ~4e9 runs.
        std::cerr << "Invalid cycle count: " << argv[1] << "\n";
        exit(EXIT_FAILURE);
    }
    const auto start{std::chrono::steady_clock::now()};
    // Unsigned counter matches `cycles`: no signed/unsigned mix, no int overflow.
    for (uint32_t i = 0; i < cycles; ++i) {
        aes(&value, key);
    }
    const auto end{std::chrono::steady_clock::now()};
    const std::chrono::duration<double> elapsed_seconds{end - start};
    auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed_seconds).count();
    std::cout << "AES (" << cycles << " runs)\nElapsed time: ";
    std::cout << milliseconds << "ms\n"; // Before C++20
    if (cycles == validatedCycles) {
        // Copy the lanes out explicitly instead of punning via a C-style cast.
        uint8_t actual[16];
        _mm_storeu_si128(reinterpret_cast<__m128i *>(actual), value);
        for (int i = 0; i < 16; i++) {
            if (actual[i] != expected[i]) {
                std::cout << "Mismatch at out[" << i << "]!\n";
                exit(EXIT_FAILURE);
            }
        }
        std::cout << "Validation successful!\n";
    } else {
        std::cout << "No results for " << cycles << " cycles precomputed. No validation.\n";
    }
    // The exit status is derived from the ciphertext: the low 32 bits, the same
    // bits the old `value[0]` vector-extension subscript produced.
    // NOTE(review): presumably this keeps the benchmark loop from being optimized
    // away. It also makes the status non-zero after a successful validation;
    // confirm this is intended.
    exit(_mm_cvtsi128_si32(value));
}