Commit 05952984 by Thierry Moreau Committed by Tianqi Chen

copy intrinsic now can include typecast (#855)

parent bf71582c
...@@ -51,12 +51,17 @@ class CopyIntrinInjector : public IRMutator { ...@@ -51,12 +51,17 @@ class CopyIntrinInjector : public IRMutator {
const Store* store = body.as<Store>(); const Store* store = body.as<Store>();
if (store == nullptr) return false; if (store == nullptr) return false;
const Select* select = store->value.as<Select>(); const Select* select = store->value.as<Select>();
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>(); const Load* load = store->value.as<Load>();
// for now only support true condition matching // for now only support true condition matching
if (select != nullptr) { if (select != nullptr) {
load = select->true_value.as<Load>(); load = select->true_value.as<Load>();
} }
// cast can be part of the pattern
if (cast != nullptr) {
load = cast->value.as<Load>();
}
if (load == nullptr) return false; if (load == nullptr) return false;
if (load->type.lanes() != 1) return false; if (load->type.lanes() != 1) return false;
Array<Var> loop_vars; Array<Var> loop_vars;
...@@ -114,7 +119,7 @@ class CopyIntrinInjector : public IRMutator { ...@@ -114,7 +119,7 @@ class CopyIntrinInjector : public IRMutator {
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size()); Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
Buffer dst = BufferNode::make( Buffer dst = BufferNode::make(
Var(store->buffer_var.node_), Var(store->buffer_var.node_),
load->type, store->value.type(),
dst_shape, dst_shape,
dst_strides, dst_strides,
store_strides[loop_vars.size()], store_strides[loop_vars.size()],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment