Changes to make tensorize work. These changes also fix the previously broken test. (#3981)
* Changes to make tensorize work. These changes also fix the previously broken test. Summary: Tensorize was breaking for a few reasons. 1) Assert at: src/op/tensorize.cc:234 CHECK(is_one(e.region[j]->extent)) In some cases this cannot be proven, e.g.: expected shape=[16, 4], given region=[range(min=((ax1.outer*16)/16), ext=(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer)), range(min=((k.outer*4)/4), ext=(((((k.outer*4) + 3)/4) + 1) - k.outer)), range(min=0, ext=16), range(min=0, ext=4)] The unprovable one is: ext=(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer)). This can be simplified but it is not because to simplify divide, it must prove ax1.outer > 0 and since it is var it cannot. The fix for this to just find all the vars in expr in relace them with some const value. 2) Equivalence between tensorized expr and one being asked to tensorize. For example, the error would be. TVMError: Check failed: Equal(lhs, rhs): Failed to match the compute with TensorIntrin tensor_intrin's declaration provided= reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[(int16)0]), source=[(int16(data(k))*int16(kernel(((((((((k.outer.outer*64) + (k.outer.inner*2)) + k)/2)*128) + i) - (k.outer.inner*128)) - (k.outer.outer*4096)), ((((k.outer.outer*64) + (k.outer.inner*2)) + k) % 2))))], axis=[iter_var(k, range(min=0, ext=2))], where=(bool)1, value_index=0), intrin= reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[(int16)0]), source=[(int16(data(k))*int16(kernel(i, k)))], axis=[iter_var(k, range(min=0, ext=2))], where=(bool)1, value_index=0) Difference is mainly in the source part: source=[(int16(data(k))*int16(kernel(((((((((k.outer.outer*64) + (k.outer.inner*2)) + k)/2)*128) + i) - (k.outer.inner*128)) - (k.outer.outer*4096)), ((((k.outer.outer*64) + (k.outer.inner*2)) + k) % 2))))] source=[(int16(data(k))*int16(kernel(i, k)))], axis=[iter_var(k, range(min=0, ext=2))] This was not being simpifiled due to compute_intrin_iter_space (map for iter var to range) not containing leaf iter vars. 3) Here it fails with: Check failed: is_one(Simplify(value->shape[i])): Argument b_buffer shape mismatch[16, 4] vs [(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer), (((((k.outer*4) + 3)/4) + 1) - k.outer), 16, 4] This is in buffer binding where it thinks expected and buffer bound shape is different. Although if we could simplify expr, this would not be the case. Test Plan: On skylake avx512 machine: python tests/python/contrib/test_gemm_acc16.py Reviewers: Subscribers: Tasks: Tags: * Implemented bounded analyzer which traverses tree and for reduce/for statements binds the bound of the analyzer. Later this is used to simplify expressions. Inspired from ir_mutator_with_analyzer Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Addressed comments. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Added ASF header + define macro for the header file: TVM_ARITHMETIC_IR_VISITOR_WITH_ANALYZER_H_ Some lint fixes as well. * Relax the assumption that dom_map must always contain all leaf itervars. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Disable copy constructor and move to raw ptr. Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Showing
src/arithmetic/ir_visitor_with_analyzer.h
0 → 100644
Please
register
or
sign in
to comment