00001
#include "ShOptimizations.hpp"
00002
#include <map>
00003
#include <set>
00004
#include <utility>
00005
#include "ShBitSet.hpp"
00006
#include "ShCtrlGraph.hpp"
00007
#include "ShDebug.hpp"
00008
#include "ShEvaluate.hpp"
00009
#include "ShContext.hpp"
00010
#include "ShSyntax.hpp"
00011
#include <sstream>
00012
#include <fstream>
00013
00014
namespace {
00015
using namespace SH;
00016
00017
00018
00019
struct ReachingDefs {
00020 ReachingDefs()
00021 : defsize(0)
00022 {
00023 }
00024
00025
struct Definition {
00026 Definition(ShStatement* stmt,
00027
const ShCtrlGraphNodePtr& node,
00028
int offset,
00029 ShBitSet disable_mask)
00030 : stmt(stmt), node(node), offset(offset),
00031 disable_mask(disable_mask)
00032 {
00033 }
00034
00035 ShStatement* stmt;
00036 ShCtrlGraphNodePtr node;
00037
int offset;
00038
00039
00040
00041
00042 ShBitSet disable_mask;
00043 };
00044
00045
typedef std::map<ShCtrlGraphNodePtr, int> SizeMap;
00046
typedef std::map<ShCtrlGraphNodePtr, ShBitSet> ReachingMap;
00047
00048 std::vector<Definition> defs;
00049
int defsize;
00050
00051 ReachingMap gen, prsv;
00052 ReachingMap rchin;
00053 };
00054
00055
struct DefFinder {
00056 DefFinder(ReachingDefs& r)
00057 : r(r), offset(0)
00058 {
00059 }
00060
00061
void operator()(ShCtrlGraphNodePtr node)
00062 {
00063
if (!node)
return;
00064 ShBasicBlockPtr block = node->block;
00065
if (!block)
return;
00066
00067
00068
00069 std::map<ShVariableNodePtr, ShBitSet> disable_map;
00070
00071
00072 ShBasicBlock::ShStmtList::iterator I = block->end();
00073
while (1) {
00074
if (I == block->begin())
break;
00075 --I;
00076
if (I->op !=
SH_OP_KIL && I->op !=
SH_OP_OPTBRA && I->dest.node()->kind() == SH_TEMP) {
00077
00078
00079
if (disable_map.find(I->dest.node()) == disable_map.end()) {
00080 disable_map[I->dest.node()] = ShBitSet(I->dest.node()->size());
00081 }
00082
00083
00084
00085 ShBitSet defn_map(I->dest.node()->size());
00086
for (
int i = 0; i < I->dest.size(); i++) defn_map[I->dest.swizzle()[i]] =
true;
00087
00088
00089
00090
if ((defn_map & disable_map[I->dest.node()]) == defn_map)
continue;
00091
00092
00093 r.defs.push_back(ReachingDefs::Definition(&(*I), node, offset,
00094 disable_map[I->dest.node()]));
00095 offset += I->dest.size();
00096 r.defsize += I->dest.size();
00097
00098
00099 disable_map[I->dest.node()] |= defn_map;
00100 }
00101 }
00102 }
00103
00104 ReachingDefs& r;
00105
int offset;
00106 };
00107
00108
struct InitRch {
00109 InitRch(ReachingDefs& r)
00110 : r(r)
00111 {
00112 }
00113
00114
void operator()(ShCtrlGraphNodePtr node)
00115 {
00116
if (!node)
return;
00117
00118 r.rchin[node] = ShBitSet(r.defsize);
00119 r.gen[node] = ShBitSet(r.defsize);
00120 r.prsv[node] = ~ShBitSet(r.defsize);
00121
00122 ShBasicBlockPtr block = node->block;
00123
if (!block)
return;
00124
00125
00126
for (
unsigned int i = 0; i < r.defs.size(); i++) {
00127
if (r.defs[i].node == node) {
00128
for (
int j = 0; j < r.defs[i].stmt->dest.size(); j++) {
00129
if (!r.defs[i].disable_mask[j]) {
00130 r.gen[node][r.defs[i].offset + j] =
true;
00131 }
00132 }
00133 }
00134 }
00135
00136
00137
for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00138
if (I->op ==
SH_OP_KIL
00139 || I->op ==
SH_OP_OPTBRA
00140 || I->dest.node()->kind() != SH_TEMP)
continue;
00141
for (
unsigned int i = 0; i < r.defs.size(); i++) {
00142
if (r.defs[i].stmt->dest.node() != I->dest.node())
continue;
00143
00144
for (
int j = 0; j < I->dest.size(); ++j) {
00145
for (
int k = 0; k < r.defs[i].stmt->dest.size(); ++k) {
00146
if (r.defs[i].stmt->dest.swizzle()[k] == I->dest.swizzle()[j]) {
00147 r.prsv[node][r.defs[i].offset + k] =
false;
00148 }
00149 }
00150 }
00151 }
00152 }
00153 }
00154
00155 ReachingDefs& r;
00156 };
00157
00158
struct IterateRch {
00159 IterateRch(ReachingDefs& r,
bool& changed)
00160 : r(r), changed(changed)
00161 {
00162 }
00163
00164
void operator()(
const ShCtrlGraphNodePtr& node)
00165 {
00166
if (!node)
return;
00167 SH_DEBUG_ASSERT(r.rchin.find(node) != r.rchin.end());
00168 ShBitSet newRchIn(r.defsize);
00169
00170
for (ShCtrlGraphNode::ShPredList::iterator I = node->predecessors.begin();
00171 I != node->predecessors.end(); ++I) {
00172 SH_DEBUG_ASSERT(r.gen.find(*I) != r.gen.end());
00173 SH_DEBUG_ASSERT(r.prsv.find(*I) != r.prsv.end());
00174 SH_DEBUG_ASSERT(r.rchin.find(*I) != r.rchin.end());
00175
00176 newRchIn |= (r.gen[*I] | (r.rchin[*I] & r.prsv[*I]));
00177 }
00178
if (newRchIn != r.rchin[node]) {
00179 r.rchin[node] = newRchIn;
00180 changed =
true;
00181 }
00182 }
00183
00184 ReachingDefs& r;
00185
bool& changed;
00186 };
00187
00188
struct UdDuBuilder {
00189 UdDuBuilder(ReachingDefs& r)
00190 : r(r)
00191 {
00192 }
00193
00194
struct TupleElement {
00195 TupleElement(
const ShVariableNodePtr& node,
int index)
00196 : node(node), index(index)
00197 {
00198 }
00199
00200
bool operator<(
const TupleElement& other)
const
00201
{
00202
if (node < other.node)
return true;
00203
if (node == other.node)
return index < other.index;
00204
return false;
00205 }
00206
00207 ShVariableNodePtr node;
00208
int index;
00209 };
00210
00211
void operator()(ShCtrlGraphNodePtr node) {
00212
typedef std::set<ValueTracking::Def> DefSet;
00213
typedef std::map<TupleElement, DefSet> DefMap;
00214
00215
00216
00217 DefMap defs;
00218
00219
if (!node)
return;
00220 ShBasicBlockPtr block = node->block;
00221
if (!block)
return;
00222
00223
00224
00225
00226
00227
for (std::size_t i = 0; i < r.defs.size(); i++) {
00228
for (
int j = 0; j < r.defs[i].stmt->dest().size(); j++) {
00229
if (r.rchin[node][r.defs[i].offset + j]) {
00230 ValueTracking::Def def(r.defs[i].stmt, j);
00231 defs[TupleElement(r.defs[i].stmt->dest.node(),
00232 r.defs[i].stmt->dest.swizzle()[j])].insert(def);
00233 }
00234 }
00235 }
00236
00237
00238
for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00239
00240
00241
for (
int j = 0; j <
opInfo[I->op].arity; j++) {
00242
if (I->src[j].node()->kind() == SH_TEMP) {
00243 ValueTracking* vt = I->template get_info<ValueTracking>();
00244
if (!vt) {
00245 vt =
new ValueTracking(&(*I));
00246 I->add_info(vt);
00247 }
00248
for (
int i = 0; i < I->src[j].size(); i++) {
00249
const DefSet& ds = defs[TupleElement(I->src[j].node(), I->src[j].swizzle()[i])];
00250
00251 vt->defs[j][i] = ds;
00252
for (DefSet::const_iterator J = ds.begin(); J != ds.end(); J++) {
00253 ValueTracking* ut = J->stmt->template get_info<ValueTracking>();
00254
if (!ut) {
00255 ut =
new ValueTracking(J->stmt);
00256 J->stmt->add_info(ut);
00257 }
00258 ut->uses[J->index].insert(ValueTracking::Use(&(*I), j, i));
00259 }
00260 }
00261 }
00262 }
00263
00264
00265
for (
int i = 0; i < I->dest.size(); ++i) {
00266 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].clear();
00267 ValueTracking::Def def(&(*I), i);
00268 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].insert(def);
00269 }
00270 }
00271
00272 }
00273
00274 ReachingDefs& r;
00275 };
00276
00277
struct UdDuClearer {
00278
void operator()(ShCtrlGraphNodePtr node) {
00279
if (!node)
return;
00280 ShBasicBlockPtr block = node->block;
00281
if (!block)
return;
00282
00283
for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00284 I->template destroy_info<ValueTracking>();
00285 }
00286 }
00287 };
00288
00289
struct UdDuDumper {
00290
void operator()(ShCtrlGraphNodePtr node) {
00291
if (!node)
return;
00292 ShBasicBlockPtr block = node->block;
00293
if (!block)
return;
00294
00295
for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00296 ValueTracking* vt = I->template get_info<ValueTracking>();
00297
if (!vt) {
00298 SH_DEBUG_PRINT(*I <<
" HAS NO VALUE TRACKING");
00299
continue;
00300 }
00301
for (
int i = 0; i <
opInfo[I->op].arity; i++) {
00302
00303
int e = 0;
00304
for (ValueTracking::TupleUseDefChain::iterator E = vt->defs[i].begin();
00305 E != vt->defs[i].end(); ++E, ++e) {
00306
for (ValueTracking::UseDefChain::iterator J = E->begin(); J != E->end(); ++J) {
00307 SH_DEBUG_PRINT(
"{" << *I <<
"}.src" << i <<
"[" << e <<
"] comes from {" << *J->stmt <<
"}.dst[" << J->index <<
"]");
00308 }
00309 }
00310 }
00311
int e = 0;
00312
for (ValueTracking::TupleDefUseChain::iterator E = vt->uses.begin();
00313 E != vt->uses.end(); ++E, ++e) {
00314
for (ValueTracking::DefUseChain::iterator J = E->begin(); J != E->end(); ++J) {
00315 SH_DEBUG_PRINT(
"{" << *I <<
"}.dst[" << e <<
"] contributes to {" << *J->stmt <<
"}.src" << J->source <<
"[" << J->index <<
"]");
00316 }
00317 }
00318 }
00319 }
00320 };
00321
00322 }
00323
00324
namespace SH {
00325
00326 ValueTracking::ValueTracking(ShStatement* stmt)
00327 : uses(stmt->dest.node() ? stmt->dest.size() : 0)
00328 {
00329
for (
int i = 0; i <
opInfo[stmt->op].arity; i++) {
00330
for (
int j = 0; j < (stmt->src[i].node() ? stmt->src[i].size() : 0); j++) {
00331 defs[i].push_back(std::set<Def>());
00332 }
00333 }
00334 }
00335
00336
/// Returns a heap-allocated copy of this object, made with the copy
/// constructor.
ShStatementInfo* ValueTracking::clone() const
{
  ValueTracking* duplicate = new ValueTracking(*this);
  return duplicate;
}
00340
00341
/// Computes reaching definitions over the control graph of `p` and attaches
/// fresh ValueTracking info (use-def / def-use chains) to its statements.
/// Any ValueTracking info already present is destroyed first.
void add_value_tracking(ShProgram& p)
{
  ReachingDefs reach;

  ShCtrlGraphPtr graph = p.node()->ctrlGraph;

  // 1. Collect the definitions and assign them bit offsets.
  DefFinder collect(reach);
  graph->dfs(collect);

  // 2. Set up the per-node rchin / gen / prsv sets.
  InitRch setup(reach);
  graph->dfs(setup);

  // 3. Propagate until a full pass leaves every rchin set untouched.
  bool changed = true;
  IterateRch step(reach, changed);
  while (changed) {
    changed = false;
    graph->dfs(step);
  }

#ifdef SH_DEBUG_OPTIMIZER
  SH_DEBUG_PRINT("Dumping Reaching Defs");
  SH_DEBUG_PRINT("defsize = " << reach.defsize);
  SH_DEBUG_PRINT("defs.size() = " << reach.defs.size());
  for (unsigned int i = 0; i < reach.defs.size(); ++i) {
    SH_DEBUG_PRINT(" stmt[" << i << "]: " << *(reach.defs[i].stmt));
    SH_DEBUG_PRINT(" node[" << i << "]: " << reach.defs[i].node.object());
    SH_DEBUG_PRINT("offset[" << i << "]: " << reach.defs[i].offset);
  }
  std::cerr << std::endl;

  for (ReachingDefs::ReachingMap::const_iterator I = reach.rchin.begin(); I != reach.rchin.end(); ++I) {
    ShCtrlGraphNodePtr node = I->first;
    SH_DEBUG_PRINT(" rchin[" << node.object() << "]: " << I->second);
    SH_DEBUG_PRINT(" gen[" << node.object() << "]: " << reach.gen[I->first]);
    SH_DEBUG_PRINT(" prsv[" << node.object() << "]: " << reach.prsv[I->first]);
    std::cerr << std::endl;
  }
#endif

  // 4. Drop stale chains, then rebuild them from the analysis result.
  UdDuClearer wipe;
  graph->dfs(wipe);

  UdDuBuilder build(reach);
  graph->dfs(build);

#ifdef SH_DEBUG_OPTIMIZER
  UdDuDumper dump;
  graph->dfs(dump);
#endif
}
00391
00392
00393
00394 }