Commit 0936858f by Richard Sandiford Committed by Richard Sandiford

Support fused multiply-adds in fully-masked reductions

This patch adds support for fusing a conditional add or subtract
with a multiplication, so that we can use fused multiply-add and
multiply-subtract operations for fully-masked reductions.  E.g.
for SVE we vectorise:

  double res = 0.0;
  for (int i = 0; i < n; ++i)
    res += x[i] * y[i];

using a fully-masked loop in which the loop body has the form:

  res_1 = PHI<0(preheader), res_2(latch)>;
  avec = .MASK_LOAD (loop_mask, a)
  bvec = .MASK_LOAD (loop_mask, b)
  prod = avec * bvec;
  res_2 = .COND_ADD (loop_mask, res_1, prod, res_1);

where the last statement does the equivalent of:

  res_2 = loop_mask ? res_1 + prod : res_1;

(operating elementwise).  The point of the patch is to convert the last
two statements into:

  res_2 = .COND_FMA (loop_mask, avec, bvec, res_1, res_1);

which is equivalent to:

  res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;

(again operating elementwise).

2018-07-12  Richard Sandiford  <richard.sandiford@linaro.org>
	    Alan Hayward  <alan.hayward@arm.com>
	    David Sherwood  <david.sherwood@arm.com>

gcc/
	* internal-fn.h (can_interpret_as_conditional_op_p): Declare.
	* internal-fn.c (can_interpret_as_conditional_op_p): New function.
	* tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
	plus and minus and convert them into IFN_COND_FMA-based sequences.
	(convert_mult_to_fma): Handle conditional plus and minus.

gcc/testsuite/
	* gcc.dg/vect/vect-fma-2.c: New test.
	* gcc.target/aarch64/sve/reduc_4.c: Likewise.
	* gcc.target/aarch64/sve/reduc_6.c: Likewise.
	* gcc.target/aarch64/sve/reduc_7.c: Likewise.

Co-Authored-By: Alan Hayward <alan.hayward@arm.com>
Co-Authored-By: David Sherwood <david.sherwood@arm.com>

From-SVN: r262588
parent b41d1f6e
2018-07-12 Richard Sandiford <richard.sandiford@linaro.org>
Alan Hayward <alan.hayward@arm.com>
David Sherwood <david.sherwood@arm.com>
* internal-fn.h (can_interpret_as_conditional_op_p): Declare.
* internal-fn.c (can_interpret_as_conditional_op_p): New function.
* tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
plus and minus and convert them into IFN_COND_FMA-based sequences.
(convert_mult_to_fma): Handle conditional plus and minus.
2018-07-12 Richard Sandiford <richard.sandiford@linaro.org>
* doc/md.texi (cond_fma, cond_fms, cond_fnma, cond_fnms): Document.
* optabs.def (cond_fma_optab, cond_fms_optab, cond_fnma_optab)
......
...@@ -3333,6 +3333,62 @@ get_unconditional_internal_fn (internal_fn ifn) ...@@ -3333,6 +3333,62 @@ get_unconditional_internal_fn (internal_fn ifn)
} }
} }
/* Return true if STMT can be interpreted as a conditional tree code
operation of the form:
LHS = COND ? OP (RHS1, ...) : ELSE;
operating elementwise if the operands are vectors. This includes
the case of an all-true COND, so that the operation always happens.
When returning true, set:
- *COND_OUT to the condition COND, or to NULL_TREE if the condition
is known to be all-true
- *CODE_OUT to the tree code
- OPS[I] to operand I of *CODE_OUT
- *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
condition is known to be all true. */
bool
can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
tree_code *code_out,
tree (&ops)[3], tree *else_out)
{
/* A plain assignment always performs its operation, so report it as a
   conditional operation with an all-true condition (NULL cond/else).  */
if (gassign *assign = dyn_cast <gassign *> (stmt))
{
*cond_out = NULL_TREE;
*code_out = gimple_assign_rhs_code (assign);
/* rhs2/rhs3 are NULL_TREE for unary and binary rhs codes, matching the
   "unused operand" convention used for the internal-function case below.  */
ops[0] = gimple_assign_rhs1 (assign);
ops[1] = gimple_assign_rhs2 (assign);
ops[2] = gimple_assign_rhs3 (assign);
*else_out = NULL_TREE;
return true;
}
/* Internal IFN_COND_* calls have the layout:
   arg 0 = condition, args 1..NOPS = operands, last arg = else value.  */
if (gcall *call = dyn_cast <gcall *> (stmt))
if (gimple_call_internal_p (call))
{
internal_fn ifn = gimple_call_internal_fn (call);
/* ERROR_MARK means IFN is not a conditional form of a tree code.  */
tree_code code = conditional_internal_fn_code (ifn);
if (code != ERROR_MARK)
{
*cond_out = gimple_call_arg (call, 0);
*code_out = code;
/* Exclude the condition (first) and else value (last) arguments.  */
unsigned int nops = gimple_call_num_args (call) - 2;
for (unsigned int i = 0; i < 3; ++i)
ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
*else_out = gimple_call_arg (call, nops + 1);
/* A known all-true condition degenerates to the unconditional form;
   normalize to the same NULL cond/else representation as assignments.  */
if (integer_truep (*cond_out))
{
*cond_out = NULL_TREE;
*else_out = NULL_TREE;
}
return true;
}
}
return false;
}
/* Return true if IFN is some form of load from memory. */ /* Return true if IFN is some form of load from memory. */
bool bool
......
...@@ -196,6 +196,9 @@ extern internal_fn get_conditional_internal_fn (tree_code); ...@@ -196,6 +196,9 @@ extern internal_fn get_conditional_internal_fn (tree_code);
extern internal_fn get_conditional_internal_fn (internal_fn); extern internal_fn get_conditional_internal_fn (internal_fn);
extern tree_code conditional_internal_fn_code (internal_fn); extern tree_code conditional_internal_fn_code (internal_fn);
extern internal_fn get_unconditional_internal_fn (internal_fn); extern internal_fn get_unconditional_internal_fn (internal_fn);
extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
tree_code *, tree (&)[3],
tree *);
extern bool internal_load_fn_p (internal_fn); extern bool internal_load_fn_p (internal_fn);
extern bool internal_store_fn_p (internal_fn); extern bool internal_store_fn_p (internal_fn);
......
2018-07-12 Richard Sandiford <richard.sandiford@linaro.org> 2018-07-12 Richard Sandiford <richard.sandiford@linaro.org>
Alan Hayward <alan.hayward@arm.com>
David Sherwood <david.sherwood@arm.com>
* gcc.dg/vect/vect-fma-2.c: New test.
* gcc.target/aarch64/sve/reduc_4.c: Likewise.
* gcc.target/aarch64/sve/reduc_6.c: Likewise.
* gcc.target/aarch64/sve/reduc_7.c: Likewise.
2018-07-12 Richard Sandiford <richard.sandiford@linaro.org>
* gcc.dg/vect/vect-cond-arith-3.c: New test. * gcc.dg/vect/vect-cond-arith-3.c: New test.
* gcc.target/aarch64/sve/vcond_13.c: Likewise. * gcc.target/aarch64/sve/vcond_13.c: Likewise.
......
/* { dg-do compile } */
/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */
#include "tree-vect.h"
#define N (VECTOR_BITS * 11 / 64 + 3)
/* Compute the dot product of the first N elements of X and Y.  */
double
dot_prod (double *x, double *y)
{
  double result = 0;
  for (int k = 0; k < N; ++k)
    result += x[k] * y[k];
  return result;
}
/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */
/* { dg-do compile } */
/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
/* Gather-load reduction: accumulate a[lookup[j]] * b[j] over 512 elements.  */
double
f (double *restrict a, double *restrict b, int *lookup)
{
  double acc = 0.0;
  for (int j = 0; j < 512; j++)
    acc += a[lookup[j]] * b[j];
  return acc;
}
/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */
/* Check that the vector instructions are the only instructions. */
/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */
/* { dg-final { scan-assembler-not {\tfadd\t} } } */
/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */
/* { dg-final { scan-assembler-not {\tsel\t} } } */
/* { dg-do compile } */
/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
/* Define reduc_<TYPE>: a dot-product reduction summing x[i] * y[i]
   over COUNT elements.  */
#define REDUC(TYPE)					\
  TYPE						\
  reduc_##TYPE (TYPE *x, TYPE *y, int count)	\
  {						\
    TYPE total = 0;				\
    for (int n = 0; n < count; ++n)		\
      total += x[n] * y[n];			\
    return total;				\
  }

REDUC (float)
REDUC (double)
/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */
/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */
/* { dg-do compile } */
/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */
/* Define reduc_<TYPE>: a negated dot-product reduction, subtracting
   x[i] * y[i] from the accumulator over COUNT elements.  */
#define REDUC(TYPE)					\
  TYPE						\
  reduc_##TYPE (TYPE *x, TYPE *y, int count)	\
  {						\
    TYPE total = 0;				\
    for (int n = 0; n < count; ++n)		\
      total -= x[n] * y[n];			\
    return total;				\
  }

REDUC (float)
REDUC (double)
/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */
/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
...@@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2) ...@@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result) FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result)
{ {
gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt); gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt);
enum tree_code use_code;
tree addop, mulop1 = op1, result = mul_result; tree addop, mulop1 = op1, result = mul_result;
bool negate_p = false; bool negate_p = false;
gimple_seq seq = NULL; gimple_seq seq = NULL;
...@@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2) ...@@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
if (is_gimple_debug (use_stmt)) if (is_gimple_debug (use_stmt))
continue; continue;
use_code = gimple_assign_rhs_code (use_stmt); if (is_gimple_assign (use_stmt)
if (use_code == NEGATE_EXPR) && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
{ {
result = gimple_assign_lhs (use_stmt); result = gimple_assign_lhs (use_stmt);
use_operand_p use_p; use_operand_p use_p;
...@@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2) ...@@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
use_stmt = neguse_stmt; use_stmt = neguse_stmt;
gsi = gsi_for_stmt (use_stmt); gsi = gsi_for_stmt (use_stmt);
use_code = gimple_assign_rhs_code (use_stmt);
negate_p = true; negate_p = true;
} }
if (gimple_assign_rhs1 (use_stmt) == result) tree cond, else_value, ops[3];
tree_code code;
if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
ops, &else_value))
gcc_unreachable ();
addop = ops[0] == result ? ops[1] : ops[0];
if (code == MINUS_EXPR)
{ {
addop = gimple_assign_rhs2 (use_stmt); if (ops[0] == result)
/* a * b - c -> a * b + (-c) */ /* a * b - c -> a * b + (-c) */
if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
addop = gimple_build (&seq, NEGATE_EXPR, type, addop); addop = gimple_build (&seq, NEGATE_EXPR, type, addop);
} else
else /* a - b * c -> (-b) * c + a */
{
addop = gimple_assign_rhs1 (use_stmt);
/* a - b * c -> (-b) * c + a */
if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR)
negate_p = !negate_p; negate_p = !negate_p;
} }
...@@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2) ...@@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree op2)
if (seq) if (seq)
gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT); gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt)); if (cond)
fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
op2, addop, else_value);
else
fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop);
gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt));
gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt)); gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt));
gsi_replace (&gsi, fma_stmt, true); gsi_replace (&gsi, fma_stmt, true);
/* Follow all SSA edges so that we generate FMS, FNMA and FNMS /* Follow all SSA edges so that we generate FMS, FNMA and FNMS
...@@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2, ...@@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
as an addition. */ as an addition. */
FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result) FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result)
{ {
enum tree_code use_code;
tree result = mul_result; tree result = mul_result;
bool negate_p = false; bool negate_p = false;
...@@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2, ...@@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
return false; return false;
if (!is_gimple_assign (use_stmt))
return false;
use_code = gimple_assign_rhs_code (use_stmt);
/* A negate on the multiplication leads to FNMA. */ /* A negate on the multiplication leads to FNMA. */
if (use_code == NEGATE_EXPR) if (is_gimple_assign (use_stmt)
&& gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR)
{ {
ssa_op_iter iter; ssa_op_iter iter;
use_operand_p usep; use_operand_p usep;
...@@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2, ...@@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
use_stmt = neguse_stmt; use_stmt = neguse_stmt;
if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) if (gimple_bb (use_stmt) != gimple_bb (mul_stmt))
return false; return false;
if (!is_gimple_assign (use_stmt))
return false;
use_code = gimple_assign_rhs_code (use_stmt);
negate_p = true; negate_p = true;
} }
switch (use_code) tree cond, else_value, ops[3];
tree_code code;
if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
&else_value))
return false;
switch (code)
{ {
case MINUS_EXPR: case MINUS_EXPR:
if (gimple_assign_rhs2 (use_stmt) == result) if (ops[1] == result)
negate_p = !negate_p; negate_p = !negate_p;
break; break;
case PLUS_EXPR: case PLUS_EXPR:
...@@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2, ...@@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
return false; return false;
} }
/* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed if (cond)
by a MULT_EXPR that we'll visit later, we might be able to {
get a more profitable match with fnma. if (cond == result || else_value == result)
return false;
if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
return false;
}
/* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
we'll visit later, we might be able to get a more profitable
match with fnma.
OTOH, if we don't, a negate / fma pair has likely lower latency OTOH, if we don't, a negate / fma pair has likely lower latency
that a mult / subtract pair. */ that a mult / subtract pair. */
if (use_code == MINUS_EXPR && !negate_p if (code == MINUS_EXPR
&& gimple_assign_rhs1 (use_stmt) == result && !negate_p
&& ops[0] == result
&& !direct_internal_fn_supported_p (IFN_FMS, type, opt_type) && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type)
&& direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)) && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)
&& TREE_CODE (ops[1]) == SSA_NAME
&& has_single_use (ops[1]))
{ {
tree rhs2 = gimple_assign_rhs2 (use_stmt); gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]);
if (is_gimple_assign (stmt2)
if (TREE_CODE (rhs2) == SSA_NAME) && gimple_assign_rhs_code (stmt2) == MULT_EXPR)
{ return false;
gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2);
if (has_single_use (rhs2)
&& is_gimple_assign (stmt2)
&& gimple_assign_rhs_code (stmt2) == MULT_EXPR)
return false;
}
} }
tree use_rhs1 = gimple_assign_rhs1 (use_stmt);
tree use_rhs2 = gimple_assign_rhs2 (use_stmt);
/* We can't handle a * b + a * b. */ /* We can't handle a * b + a * b. */
if (use_rhs1 == use_rhs2) if (ops[0] == ops[1])
return false; return false;
/* If deferring, make sure we are not looking at an instruction that /* If deferring, make sure we are not looking at an instruction that
wouldn't have existed if we were not. */ wouldn't have existed if we were not. */
if (state->m_deferring_p if (state->m_deferring_p
&& (state->m_mul_result_set.contains (use_rhs1) && (state->m_mul_result_set.contains (ops[0])
|| state->m_mul_result_set.contains (use_rhs2))) || state->m_mul_result_set.contains (ops[1])))
return false; return false;
if (check_defer) if (check_defer)
{ {
tree use_lhs = gimple_assign_lhs (use_stmt); tree use_lhs = gimple_get_lhs (use_stmt);
if (state->m_last_result) if (state->m_last_result)
{ {
if (use_rhs2 == state->m_last_result if (ops[1] == state->m_last_result
|| use_rhs1 == state->m_last_result) || ops[0] == state->m_last_result)
defer = true; defer = true;
else else
defer = false; defer = false;
...@@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2, ...@@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
{ {
gcc_checking_assert (!state->m_initial_phi); gcc_checking_assert (!state->m_initial_phi);
gphi *phi; gphi *phi;
if (use_rhs1 == result) if (ops[0] == result)
phi = result_of_phi (use_rhs2); phi = result_of_phi (ops[1]);
else else
{ {
gcc_assert (use_rhs2 == result); gcc_assert (ops[1] == result);
phi = result_of_phi (use_rhs1); phi = result_of_phi (ops[0]);
} }
if (phi) if (phi)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment