Users are free to choose between these two options, as well as any intermediate ones (e.g., specifying some of the parameters at creation time while leaving the others until execution time). This enables balancing between flexibility and performance.
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
namespace {
// Overwrites every element of `v` with a pseudo-random float drawn from a
// uniform distribution constructed with bounds (-1, 1).
//
// The Mersenne Twister generator is default-constructed on each call, so two
// calls on vectors of the same length produce identical contents.
void init_vector(std::vector<float> &v) {
    std::mt19937 rng;
    std::uniform_real_distribution<float> dist(-1, 1);
    for (size_t i = 0; i < v.size(); ++i) {
        v[i] = dist(rng);
    }
}
// Checks that `v2` matches the reference `v1` within a relative L2-norm
// tolerance and prints a short report to stdout.
//
// The check passes when ||v1 - v2||_2 <= threshold * ||v1||_2, where
// threshold = float epsilon * log(max(2, K)).
//
// @param v1      Reference values.
// @param v2      Values under test; must have the same length as v1.
// @param K       Value used to scale the tolerance (callers in this file pass
//                the reduction dimension of the matrix product).
// @param message Header line printed before the report.
// @return 0 when the check passes, 1 otherwise (including a length mismatch).
int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2,
        int64_t K, const char *message) {
    // Indexing v2 by v1's range would read out of bounds when v2 is shorter,
    // and a partial comparison is meaningless when it is longer.
    if (v1.size() != v2.size()) {
        printf("%s\n\tSize mismatch: %zu vs %zu elements\n"
               "\tAccuracy check: FAILED\n",
                message, v1.size(), v2.size());
        return 1;
    }

    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        // Promote to double before multiplying so the squares are not
        // rounded to float precision before being accumulated.
        const double ref = v1[n];
        const double diff = ref - static_cast<double>(v2[n]);
        v1_l2 += ref * ref;
        diff_l2 += diff * diff;
    }
    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);

    const double threshold = std::numeric_limits<float>::epsilon()
            * std::log(std::max(2., static_cast<double>(K)));
    const bool ok = diff_l2 <= threshold * v1_l2;

    // Avoid printing 0/0 = nan when the reference is all zeros (or empty).
    double relative_error = 0.0;
    if (v1_l2 != 0) {
        relative_error = diff_l2 / v1_l2;
    } else if (diff_l2 != 0) {
        relative_error = std::numeric_limits<double>::infinity();
    }

    printf("%s\n\tL2 Norms"
           "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n"
           "\tAccuracy check: %s\n",
            message, v1_l2, diff_l2, relative_error, ok ? "OK" : "FAILED");

    return ok ? 0 : 1;
}
}
// Number of times each of the three implementations (SGEMM, dynamic MatMul,
// static MatMul) is executed by sgemm_and_matmul_with_params().
int number_of_runs = 1;
// The beta value that dynamic_matmul_create() uses when building the primitive.
// Functions that receive a run-time beta throw std::logic_error when it
// differs from this value.
float fixed_beta = 0.f;
// CPU engine (index 0) used for every primitive descriptor, memory object and
// stream created in this file.
engine eng(engine::kind::cpu, 0);
// Creates a MatMul primitive whose matrix sizes, leading dimensions and
// output scale (alpha) are left unspecified until execution time.
//
// Only beta is fixed at creation time (taken from `fixed_beta`); when it is
// non-zero a sum post-op with that scale is attached.
//
// @return The created MatMul primitive, to be run with
//         dynamic_matmul_execute().
matmul dynamic_matmul_create() {
    // Beta is assumed to be known at primitive creation time.
    const float beta = fixed_beta;

    // Every dimension is a run-time wildcard; the real sizes are supplied by
    // the memory objects passed at execution.
    const memory::dims a_shape = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};
    const memory::dims b_shape = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};
    const memory::dims c_shape = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};

    // A and B may be passed with either stride order at execution, so both of
    // their strides are wildcards. C is always executed with strides
    // {ldc, 1}, so only its first stride is unknown.
    const memory::dims a_strides = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};
    const memory::dims b_strides = {DNNL_RUNTIME_DIM_VAL, DNNL_RUNTIME_DIM_VAL};
    const memory::dims c_strides = {DNNL_RUNTIME_DIM_VAL, 1};

    memory::desc a_md(a_shape, memory::data_type::f32, a_strides);
    memory::desc b_md(b_shape, memory::data_type::f32, b_strides);
    memory::desc c_md(c_shape, memory::data_type::f32, c_strides);

    // Alpha is handled as a single (mask 0) output scale whose value is
    // provided at execution time; beta, if any, becomes a sum post-op.
    primitive_attr attr;
    attr.set_output_scales(/* mask */ 0, {DNNL_RUNTIME_F32_VAL});
    if (beta != 0.f) {
        post_ops po;
        po.append_sum(beta);
        attr.set_post_ops(po);
    }

    matmul::desc matmul_d(a_md, b_md, c_md);
    matmul::primitive_desc matmul_pd(matmul_d, attr, eng);
    return matmul(matmul_pd);
}
// Runs a MatMul primitive created by dynamic_matmul_create() on user buffers,
// supplying the sizes, strides and alpha at execution time.
//
// @param matmul_p       Primitive to execute.
// @param transA, transB 'n'/'N' selects strides {ld, 1} for the matrix; any
//                       other value selects {1, ld}.
// @param M, N, K        A is described as M x K, B as K x N, C as M x N.
// @param alpha          Output scale, passed as a run-time argument.
// @param A, B           Input buffers (not written through by this function).
// @param lda, ldb, ldc  Leading dimensions used to build the strides.
// @param beta           Must equal `fixed_beta`.
// @param C              Destination buffer, described with strides {ldc, 1}.
// @throws std::logic_error if `beta` differs from `fixed_beta`.
void dynamic_matmul_execute(matmul &matmul_p, char transA, char transB,
        int64_t M, int64_t N, int64_t K, float alpha, const float *A,
        int64_t lda, const float *B, int64_t ldb, float beta, float *C,
        int64_t ldc) {
    using dims = memory::dims;

    if (beta != fixed_beta)
        throw std::logic_error("Run-time beta is not yet supported.");

    // std::tolower requires a value representable as unsigned char.
    const bool a_plain = std::tolower(static_cast<unsigned char>(transA)) == 'n';
    const bool b_plain = std::tolower(static_cast<unsigned char>(transB)) == 'n';
    const dims a_strides = a_plain ? dims {lda, 1} : dims {1, lda};
    const dims b_strides = b_plain ? dims {ldb, 1} : dims {1, ldb};

    // Wrap the user buffers. The memory constructor takes a non-const handle,
    // hence the const_cast on the input pointers.
    memory A_m({{M, K}, memory::data_type::f32, a_strides}, eng,
            const_cast<float *>(A));
    memory B_m({{K, N}, memory::data_type::f32, b_strides}, eng,
            const_cast<float *>(B));
    memory C_m({{M, N}, memory::data_type::f32, {ldc, 1}}, eng, C);

    // alpha lives in this function's frame; that is fine because the stream
    // is waited on before returning.
    memory alpha_m({{1}, memory::data_type::f32, {1}}, eng, &alpha);

    stream s(eng);
    matmul_p.execute(s,
            {{DNNL_ARG_SRC, A_m}, {DNNL_ARG_WEIGHTS, B_m}, {DNNL_ARG_DST, C_m},
                    {DNNL_ARG_ATTR_OUTPUT_SCALES, alpha_m}});
    s.wait();
}
void static_matmul_create_and_execute(char transA, char transB, int64_t M,
int64_t N, int64_t K, float alpha, const float *A, int64_t lda,
const float *B, int64_t ldb, float beta, float *C, int64_t ldc) {
using dims = memory::dims;
dims a_strides = tolower(transA) == 'n' ? dims {lda, 1} : dims {1, lda};
dims b_strides = tolower(transB) == 'n' ? dims {ldb, 1} : dims {1, ldb};
memory::desc a_md({M, K}, memory::data_type::f32, a_strides);
memory::desc b_md({K, N}, memory::data_type::f32, b_strides);
memory::desc c_md({M, N}, memory::data_type::f32, {ldc, 1});
primitive_attr attr;
if (alpha != 1.f) attr.set_output_scales( 0, {alpha});
if (beta != 0.f) {
post_ops po;
po.append_sum(beta);
attr.set_post_ops(po);
}
matmul::desc
matmul_d(a_md, b_md, c_md);
matmul::primitive_desc matmul_pd(matmul_d, attr, eng);
matmul matmul_p(matmul_pd);
memory A_m(a_md, eng, (void *)A);
memory B_m(b_md, eng, (void *)B);
memory C_m(c_md, eng, (void *)C);
stream s(eng);
matmul_p.execute(s,
s.wait();
}
void sgemm_and_matmul_with_params(char transA, char transB, int64_t M,
int64_t N, int64_t K, float alpha, float beta) {
if (beta != fixed_beta)
throw std::logic_error("Run-time beta is not yet supported.");
std::vector<float> A(M * K);
init_vector(A);
std::vector<float> B(K * N);
init_vector(B);
std::vector<float> C_sgemm(M * N);
init_vector(C_sgemm);
std::vector<float> C_dynamic_matmul = C_sgemm;
std::vector<float> C_static_matmul = C_sgemm;
int64_t lda = tolower(transA) == 'n' ? K : M;
int64_t ldb = tolower(transB) == 'n' ? N : K;
int64_t ldc = N;
for (int run = 0; run < number_of_runs; ++run)
dnnl_sgemm(transA, transB, M, N, K, alpha, A.data(), lda, B.data(), ldb,
beta, C_sgemm.data(), ldc);
auto dynamic_matmul = dynamic_matmul_create();
for (int run = 0; run < number_of_runs; ++run)
dynamic_matmul_execute(dynamic_matmul, transA, transB, M, N, K, alpha,
A.data(), lda, B.data(), ldb, beta, C_dynamic_matmul.data(),
ldc);
for (int run = 0; run < number_of_runs; ++run)
static_matmul_create_and_execute(transA, transB, M, N, K, alpha,
A.data(), lda, B.data(), ldb, beta, C_static_matmul.data(),
ldc);
int rc = 0;
rc |= compare_vectors(
C_sgemm, C_dynamic_matmul, K, "Compare SGEMM vs dynamic MatMul");
if (rc) throw std::logic_error("The resulting matrices diverged too much.");
rc |= compare_vectors(
C_sgemm, C_static_matmul, K, "Compare SGEMM vs static MatMul");
if (rc) throw std::logic_error("The resulting matrices diverged too much.");
}
// Runs the SGEMM / MatMul comparison on one fixed problem:
// transA = 'N', transB = 'T', M = 10, N = 20, K = 30, alpha = 1.1 and
// beta = fixed_beta.
void sgemm_and_matmul() {
    const char trans_a = 'N';
    const char trans_b = 'T';
    const int64_t m = 10;
    const int64_t n = 20;
    const int64_t k = 30;
    const float alpha = 1.1f;
    sgemm_and_matmul_with_params(trans_a, trans_b, m, n, k, alpha, fixed_beta);
}
// Program entry point: hands sgemm_and_matmul to handle_example_errors for
// the CPU engine kind and returns whatever that helper returns.
// The command-line arguments are not used.
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;
    const int exit_code
            = handle_example_errors({engine::kind::cpu}, sgemm_and_matmul);
    return exit_code;
}
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
#define DNNL_RUNTIME_DIM_VAL
A wildcard value for dimensions that are unknown at a primitive creation time.
Definition: dnnl_types.h:1305
#define DNNL_RUNTIME_F32_VAL
A wildcard value for floating point values that are unknown at a primitive creation time.
Definition: dnnl_types.h:1322
#define DNNL_ARG_ATTR_OUTPUT_SCALES
Output scaling factors provided at execution time.
Definition: dnnl_types.h:2446
#define DNNL_ARG_DST
A special mnemonic for destination argument for primitives that have a single destination.
Definition: dnnl_types.h:2307
#define DNNL_ARG_SRC
A special mnemonic for source argument for primitives that have a single source.
Definition: dnnl_types.h:2283
#define DNNL_ARG_WEIGHTS
A special mnemonic for primitives that have a single weights argument.
Definition: dnnl_types.h:2330
@ matmul_d
matmul descriptor
oneDNN namespace
Definition: dnnl.hpp:74