// This example demonstrates memory format propagation, which is critical for the performance of deep learning applications.
#include <iostream>
#include <sstream>
#include <string>
#include "example_utils.hpp"
void memory_format_propagation_tutorial(engine::kind engine_kind) {
stream s(eng);
const int N = 1, H = 14, W = 14, IC = 128, OC = 256, KH = 3, KW = 3;
auto conv_src_md = memory::desc({N, IC, H, W}, memory::data_type::f32,
memory::format_tag::any
);
auto conv_weights_md = memory::desc(
{OC, IC, KH, KW}, memory::data_type::f32,
memory::format_tag::any
);
auto conv_dst_md = memory::desc({N, OC, H, W}, memory::data_type::f32,
memory::format_tag::any
);
auto pool_dst_md = conv_dst_md;
auto conv_pd = convolution_forward::primitive_desc(
{prop_kind::forward_inference, algorithm::convolution_auto,
conv_src_md, conv_weights_md,
conv_dst_md,
{1, 1},
{1, 1}, {1, 1}},
eng);
auto pool_pd = pooling_forward::primitive_desc(
{prop_kind::forward_inference, algorithm::pooling_max,
conv_pd.dst_desc(), pool_dst_md,
{1, 1}, {KH, KW},
{1, 1}, {1, 1}},
eng);
auto src_mem = memory(
{{N, IC, H, W}, memory::data_type::f32, memory::format_tag::nchw},
eng);
auto weights_mem = memory({{OC, IC, KH, KW}, memory::data_type::f32,
memory::format_tag::oihw},
eng);
auto dst_mem = memory(
{{N, OC, H, W}, memory::data_type::f32, memory::format_tag::nchw},
eng);
bool need_reorder_src = conv_pd.src_desc() != src_mem.get_desc();
bool need_reorder_weights
= conv_pd.weights_desc() != weights_mem.get_desc();
bool need_reorder_dst = conv_pd.dst_desc() != dst_mem.get_desc();
auto conv_src_mem
= need_reorder_src ? memory(conv_pd.src_desc(), eng) : src_mem;
auto conv_weights_mem = need_reorder_weights
? memory(conv_pd.weights_desc(), eng)
: weights_mem;
auto conv_dst_mem = memory(conv_pd.dst_desc(), eng);
auto pool_dst_mem
= need_reorder_dst ? memory(pool_pd.dst_desc(), eng) : dst_mem;
if (need_reorder_src) {
auto reorder_src = reorder(src_mem, conv_src_mem);
reorder_src.execute(
s.wait();
}
if (need_reorder_weights) {
auto reorder_weights = reorder(weights_mem, conv_weights_mem);
reorder_weights.execute(s,
s.wait();
}
auto conv_scratchpad_mem = memory(conv_pd.scratchpad_desc(), eng);
auto conv = convolution_forward(conv_pd);
conv.execute(s,
auto pool_scratchpad_mem = memory(pool_pd.scratchpad_desc(), eng);
auto pool = pooling_forward(pool_pd);
pool.execute(
s.wait();
if (need_reorder_dst) {
auto reorder_dst = reorder(pool_dst_mem, dst_mem);
reorder_dst.execute(
s.wait();
}
}
/// Program entry point: determines the engine kind from the command line and
/// runs the tutorial through the shared example error handler, whose result
/// becomes the process exit code.
int main(int argc, char **argv) {
    const auto selected_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(
            memory_format_propagation_tutorial, selected_kind);
}
// Reference notes for the oneDNN symbols used in this example:
//
// DNNL_ARG_DST
//     A special mnemonic for the destination argument of primitives that have
//     a single destination. Definition: dnnl_types.h:2307
// DNNL_ARG_FROM
//     A special mnemonic for the reorder source argument.
//     Definition: dnnl_types.h:2289
// DNNL_ARG_SRC
//     A special mnemonic for the source argument of primitives that have a
//     single source. Definition: dnnl_types.h:2283
// DNNL_ARG_WEIGHTS
//     A special mnemonic for the weights argument of primitives that have a
//     single weights argument. Definition: dnnl_types.h:2330
// DNNL_ARG_TO
//     A special mnemonic for the reorder destination argument.
//     Definition: dnnl_types.h:2310
// dnnl
//     oneDNN namespace. Definition: dnnl.hpp:74