00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00020 #include <map>
00021 #include <set>
00022 #include <utility>
00023 #include <iostream>
00024 #include <sstream>
00025 #include <fstream>
00026 #include "ShOptimizations.hpp"
00027 #include "ShBitSet.hpp"
00028 #include "ShCtrlGraph.hpp"
00029 #include "ShDebug.hpp"
00030 #include "ShEvaluate.hpp"
00031 #include "ShContext.hpp"
00032 #include "ShSyntax.hpp"
00033 #include "ShProgramNode.hpp"
00034
00035
00036
00037
// Turning on optimizer debugging also turns on value-tracking debugging,
// unless the latter has already been requested explicitly.
#if defined(SH_DEBUG_OPTIMIZER) && !defined(SH_DEBUG_VALUETRACK)
#define SH_DEBUG_VALUETRACK
#endif
00043 namespace {
00044 using namespace SH;
00045
00046
00047
00048 struct ReachingDefs {
00049 ReachingDefs()
00050 : defsize(0)
00051 {
00052 }
00053
00054 struct Definition {
00055 Definition(ShStatement* stmt,
00056 const ShCtrlGraphNodePtr& node,
00057 int offset,
00058 ShBitSet disable_mask)
00059 : varnode(stmt->dest.node()), stmt(stmt), node(node), offset(offset),
00060 disable_mask(disable_mask)
00061 {}
00062
00063 Definition(ShVariableNodePtr varnode,
00064 const ShCtrlGraphNodePtr& node,
00065 int offset,
00066 ShBitSet disable_mask)
00067 : varnode(varnode), stmt(0), node(node), offset(offset),
00068 disable_mask(disable_mask)
00069 {}
00070
00071 bool isInput() const
00072 {
00073 return !stmt;
00074 }
00075
00076 int size() const
00077 {
00078 if(isInput()) return varnode->size();
00079 else return stmt->dest.size();
00080 }
00081
00082
00083 int isDisabled(int i) const
00084 {
00085 return disable_mask[index(i)];
00086 }
00087
00088
00089 int off(int i) const
00090 {
00091 return offset + i;
00092 }
00093
00094
00095 int index(int i) const
00096 {
00097 if(isInput()) return i;
00098 return stmt->dest.swizzle()[i];
00099 }
00100
00101
00102 ValueTracking::Def toDef(int i) const
00103 {
00104 if(isInput()) return ValueTracking::Def(varnode, i);
00105 return ValueTracking::Def(stmt, i);
00106 }
00107
00108 friend std::ostream& operator<<(std::ostream& out, const Definition& def)
00109 {
00110 out << "{";
00111 out << " off " << def.offset;
00112 out << ", node " << def.node.object();
00113 out << ", dsbl " << def.disable_mask << ", ";
00114 if(def.isInput()) out << "input " << def.varnode->name();
00115 else out << "stmt " << *def.stmt;
00116 out << "}";
00117 return out;
00118 }
00119
00120 ShVariableNodePtr varnode;
00121 ShStatement* stmt;
00122 ShCtrlGraphNodePtr node;
00123 int offset;
00124
00125
00126
00127
00128
00129 ShBitSet disable_mask;
00130 };
00131
00132 void addDefinition(const Definition& d)
00133 {
00134
00135 defs.push_back(d);
00136 defsize += d.size();
00137 }
00138
00139 typedef std::map<ShCtrlGraphNodePtr, int> SizeMap;
00140 typedef std::map<ShCtrlGraphNodePtr, ShBitSet> ReachingMap;
00141
00142 std::vector<Definition> defs;
00143 int defsize;
00144
00145 ReachingMap gen, prsv;
00146 ReachingMap rchin;
00147 };
00148
00149 struct DefFinder {
00150 DefFinder(ReachingDefs& r, ShCtrlGraphNodePtr entry, const ShProgramNode::VarList& inputs)
00151 : entry(entry), inputs(inputs), r(r), offset(0)
00152 {
00153 }
00154
00155 void operator()(ShCtrlGraphNodePtr node)
00156 {
00157 if (!node) return;
00158 ShBasicBlockPtr block = node->block;
00159
00160
00161
00162 std::map<ShVariableNodePtr, ShBitSet> disable_map;
00163
00164 if(block) {
00165
00166 ShBasicBlock::ShStmtList::iterator I = block->end();
00167 while (1) {
00168 if (I == block->begin()) break;
00169 --I;
00170 if (!I->dest.null() && I->op != SH_OP_KIL && I->op != SH_OP_OPTBRA ) {
00171
00172
00173 if (disable_map.find(I->dest.node()) == disable_map.end()) {
00174 disable_map[I->dest.node()] = ShBitSet(I->dest.node()->size());
00175 }
00176
00177
00178
00179 ShBitSet defn_map(I->dest.node()->size());
00180 for (int i = 0; i < I->dest.size(); i++) defn_map[I->dest.swizzle()[i]] = true;
00181
00182
00183
00184 if ((defn_map & disable_map[I->dest.node()]) == defn_map) continue;
00185
00186
00187 r.addDefinition(ReachingDefs::Definition(&(*I), node, r.defsize,
00188 disable_map[I->dest.node()]));
00189
00190
00191 disable_map[I->dest.node()] |= defn_map;
00192 }
00193 }
00194 }
00195
00196
00197 if(node != entry) return;
00198 for(ShProgramNode::VarList::const_iterator I = inputs.begin();
00199 I != inputs.end(); ++I) {
00200 if(disable_map.find(*I) == disable_map.end()) {
00201 disable_map[*I] = ShBitSet((*I)->size());
00202 }
00203
00204
00205 if(disable_map[*I].full()) continue;
00206
00207 r.addDefinition(ReachingDefs::Definition(*I, node, r.defsize,
00208 disable_map[*I]));
00209 }
00210 }
00211
00212 ShCtrlGraphNodePtr entry;
00213 const ShProgramNode::VarList& inputs;
00214 ReachingDefs& r;
00215 int offset;
00216 };
00217
00218 struct InitRch {
00219 InitRch(ReachingDefs& r)
00220 : r(r)
00221 {
00222 }
00223
00224 void operator()(ShCtrlGraphNodePtr node)
00225 {
00226 if (!node) return;
00227
00228 r.rchin[node] = ShBitSet(r.defsize);
00229 r.gen[node] = ShBitSet(r.defsize);
00230 r.prsv[node] = ~ShBitSet(r.defsize);
00231
00232 ShBasicBlockPtr block = node->block;
00233
00234
00235 for (unsigned int i = 0; i < r.defs.size(); i++) {
00236 ReachingDefs::Definition &d = r.defs[i];
00237 if (d.node == node) {
00238 for (int j = 0; j < d.size(); j++) {
00239
00240
00241
00242 if (!d.isDisabled(j)) {
00243 r.gen[node][d.off(j)] = true;
00244 }
00245 }
00246 }
00247 }
00248
00249 if(!block) return;
00250
00251
00252 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00253 if (I->dest.null()
00254 || I->op == SH_OP_KIL
00255 || I->op == SH_OP_OPTBRA
00256 ) continue;
00257 for (unsigned int i = 0; i < r.defs.size(); i++) {
00258 ReachingDefs::Definition &d = r.defs[i];
00259 if (d.varnode != I->dest.node()) continue;
00260
00261 for (int j = 0; j < I->dest.size(); ++j) {
00262 for (int k = 0; k < d.size(); ++k) {
00263 if (d.index(k) == I->dest.swizzle()[j]) {
00264 r.prsv[node][d.off(k)] = false;
00265 }
00266 }
00267 }
00268 }
00269 }
00270 }
00271
00272 ReachingDefs& r;
00273 };
00274
00275 struct IterateRch {
00276 IterateRch(ReachingDefs& r, bool& changed)
00277 : r(r), changed(changed)
00278 {
00279 }
00280
00281 void operator()(const ShCtrlGraphNodePtr& node)
00282 {
00283 if (!node) return;
00284 SH_DEBUG_ASSERT(r.rchin.find(node) != r.rchin.end());
00285 ShBitSet newRchIn(r.defsize);
00286
00287 for (ShCtrlGraphNode::ShPredList::iterator I = node->predecessors.begin();
00288 I != node->predecessors.end(); ++I) {
00289 SH_DEBUG_ASSERT(r.gen.find(*I) != r.gen.end());
00290 SH_DEBUG_ASSERT(r.prsv.find(*I) != r.prsv.end());
00291 SH_DEBUG_ASSERT(r.rchin.find(*I) != r.rchin.end());
00292
00293 newRchIn |= (r.gen[*I] | (r.rchin[*I] & r.prsv[*I]));
00294 }
00295 if (newRchIn != r.rchin[node]) {
00296 r.rchin[node] = newRchIn;
00297 changed = true;
00298 }
00299 }
00300
00301 ReachingDefs& r;
00302 bool& changed;
00303 };
00304
00305
00306
00307 struct UdDuBuilder {
00308 UdDuBuilder(ReachingDefs& r, ShProgramNodePtr p)
00309 : r(r), m_exit(p->ctrlGraph->exit()), p(p),
00310 intrack(new InputValueTracking()),
00311 outtrack(new OutputValueTracking())
00312 {
00313 p->destroy_info<InputValueTracking>();
00314 p->destroy_info<OutputValueTracking>();
00315 p->add_info(intrack);
00316 p->add_info(outtrack);
00317 }
00318
00319 struct TupleElement {
00320 TupleElement(const ShVariableNodePtr& node, int index)
00321 : node(node), index(index)
00322 {
00323 }
00324
00325 bool operator<(const TupleElement& other) const
00326 {
00327 if (node < other.node) return true;
00328 if (node == other.node) return index < other.index;
00329 return false;
00330 }
00331
00332 ShVariableNodePtr node;
00333 int index;
00334 };
00335
00336 void operator()(ShCtrlGraphNodePtr node) {
00337 typedef std::set<ValueTracking::Def> DefSet;
00338 typedef std::map<TupleElement, DefSet> DefMap;
00339
00340
00341
00342 DefMap defs;
00343
00344 if (!node) return;
00345 ShBasicBlockPtr block = node->block;
00346
00347
00348
00349
00350
00351 for (std::size_t i = 0; i < r.defs.size(); i++) {
00352 for (int j = 0; j < r.defs[i].size(); j++) {
00353 if (r.rchin[node][r.defs[i].offset + j]) {
00354 ValueTracking::Def def(r.defs[i].toDef(j));
00355 defs[TupleElement(r.defs[i].varnode,
00356 r.defs[i].index(j))].insert(def);
00357 }
00358 }
00359 }
00360
00361 if(block) {
00362
00363 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00364
00365
00366 for (int j = 0; j < opInfo[I->op].arity; j++) {
00367
00368 ValueTracking* vt = I->get_info<ValueTracking>();
00369 if (!vt) {
00370 vt = new ValueTracking(&(*I));
00371 I->add_info(vt);
00372 }
00373 ShVariableNodePtr srcNode = I->src[j].node();
00374 for (int i = 0; i < I->src[j].size(); i++) {
00375 const DefSet& ds = defs[TupleElement(srcNode, I->src[j].swizzle()[i])];
00376
00377 vt->defs[j][i] = ds;
00378 ValueTracking::Use srcUse(&(*I), j, i);
00379 for (DefSet::const_iterator J = ds.begin(); J != ds.end(); J++) {
00380 switch(J->kind) {
00381 case ValueTracking::Def::INPUT:
00382 {
00383 ValueTracking::TupleDefUseChain& inputDu =
00384 intrack->inputUses[srcNode];
00385 if(inputDu.empty()) {
00386 inputDu.resize(srcNode->size());
00387 }
00388 inputDu[J->index].insert(srcUse);
00389 break;
00390 }
00391 case ValueTracking::Def::STMT:
00392 {
00393 ValueTracking* ut = J->stmt->get_info<ValueTracking>();
00394 if (!ut) {
00395 ut = new ValueTracking(J->stmt);
00396 J->stmt->add_info(ut);
00397 }
00398 ut->uses[J->index].insert(srcUse);
00399 break;
00400 }
00401 }
00402 }
00403
00404 }
00405 }
00406
00407 for (int i = 0; i < I->dest.size(); ++i) {
00408 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].clear();
00409 ValueTracking::Def def(&(*I), i);
00410 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].insert(def);
00411 }
00412 }
00413 }
00414
00415
00416
00417
00418
00419 if(node == m_exit) {
00420 for(DefMap::iterator D = defs.begin(); D != defs.end(); ++D) {
00421 ShVariableNodePtr node = D->first.node;
00422 ShBindingType kind = node->kind();
00423 int index = D->first.index;
00424
00425 if(kind == SH_INOUT || kind == SH_OUTPUT) {
00426 DefSet& ds = D->second;
00427
00428 ValueTracking::TupleUseDefChain& outDef = outtrack->outputDefs[node];
00429 if(outDef.empty()) {
00430 outDef.resize(node->size());
00431 }
00432 outDef[index] = ds;
00433
00434 for(DefSet::iterator S = ds.begin(); S != ds.end(); ++S) {
00435 if(S->kind != ValueTracking::Def::STMT) continue;
00436 ValueTracking* vt = S->stmt->get_info<ValueTracking>();
00437 if (!vt) {
00438 vt = new ValueTracking(S->stmt);
00439 S->stmt->add_info(vt);
00440 }
00441 SH_DEBUG_ASSERT(vt);
00442 ValueTracking::Use use(node, S->stmt->dest.swizzle()[S->index]);
00443 vt->uses[S->index].insert(use);
00444 }
00445 }
00446 }
00447 }
00448 }
00449
00450 ReachingDefs& r;
00451 ShCtrlGraphNodePtr m_exit;
00452 ShProgramNodePtr p;
00453 InputValueTracking* intrack;
00454 OutputValueTracking* outtrack;
00455 };
00456
00457 struct UdDuClearer {
00458 void operator()(ShCtrlGraphNodePtr node) {
00459 if (!node) return;
00460 ShBasicBlockPtr block = node->block;
00461 if (!block) return;
00462
00463 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00464 I->destroy_info<ValueTracking>();
00465 }
00466 }
00467 };
00468
00469 struct UdDuDumper {
00470 void operator()(ShCtrlGraphNodePtr node) {
00471 if (!node) return;
00472 ShBasicBlockPtr block = node->block;
00473 if (!block) return;
00474
00475 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00476 ValueTracking* vt = I->get_info<ValueTracking>();
00477 if (!vt) {
00478 SH_DEBUG_PRINT(*I << " HAS NO VALUE TRACKING");
00479 continue;
00480 }
00481 SH_DEBUG_PRINT("Valuetracking for " << *I);
00482 for (int i = 0; i < opInfo[I->op].arity; i++) {
00483 SH_DEBUG_PRINT(" src ud" << i << "\n" << vt->defs[i]);
00484 }
00485 SH_DEBUG_PRINT(" dest du" << vt->uses);
00486 }
00487 }
00488
00489 void operator()(ShProgramNodePtr p) {
00490 #ifdef SH_DEBUG
00491 InputValueTracking* ivt = p->get_info<InputValueTracking>();
00492 SH_DEBUG_ASSERT(ivt);
00493 SH_DEBUG_PRINT("Input Valuetracking:\n" << *ivt);
00494
00495 OutputValueTracking* ovt = p->get_info<OutputValueTracking>();
00496 SH_DEBUG_ASSERT(ovt);
00497 SH_DEBUG_PRINT("Output Valuetracking:\n" << *ovt);
00498 #endif
00499 }
00500 };
00501
00502 }
00503
00504 namespace SH {
00505
00506 ValueTracking::ValueTracking(ShStatement* stmt)
00507 : uses(stmt->dest.node() ? stmt->dest.size() : 0),
00508 defs(stmt->src.size())
00509 {
00510 #ifdef SH_DEBUG_VALUETRACK
00511 SH_DEBUG_PRINT("Adding value tracking to " << *stmt);
00512 #endif
00513 for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00514 for (int j = 0; j < (stmt->src[i].node() ? stmt->src[i].size() : 0); j++) {
00515 defs[i].push_back(std::set<Def>());
00516 }
00517 }
00518 }
00519
00520 ShInfo* ValueTracking::clone() const
00521 {
00522 return new ValueTracking(*this);
00523 }
00524
00525 std::ostream& operator<<(std::ostream& out, const ValueTracking::Use& use) {
00526 if(use.kind == ValueTracking::Use::STMT) {
00527 out << "(" << *use.stmt << ").src" << use.source << "[" << use.index << "]";
00528 } else {
00529 out << "(OUTPUT " << use.node->name() << ")[" << use.index << "]";
00530 }
00531 return out;
00532 }
00533
00534 std::ostream& operator<<(std::ostream& out, const ValueTracking::TupleUseDefChain& tud)
00535 {
00536 int e = 0;
00537 for (ValueTracking::TupleUseDefChain::const_iterator E = tud.begin();
00538 E != tud.end(); ++E, ++e) {
00539 out << "[" << e << "] <- {";
00540 for (ValueTracking::UseDefChain::const_iterator J = E->begin(); J != E->end(); ++J) {
00541 if(J != E->begin()) out << ", ";
00542 out << *J;
00543 }
00544 out << "}\n";
00545 }
00546 return out;
00547 }
00548
00549 std::ostream& operator<<(std::ostream& out, const ValueTracking::Def& def) {
00550 if(def.kind == ValueTracking::Def::STMT) {
00551 out << "(" << *def.stmt << ").dst" << "[" << def.index << "]";
00552 } else {
00553 out << "(INPUT " << def.node->name() << ")[" << def.index << "]";
00554 }
00555 return out;
00556 }
00557
00558 std::ostream& operator<<(std::ostream& out, const ValueTracking::TupleDefUseChain& tdu)
00559 {
00560 int e = 0;
00561 for (ValueTracking::TupleDefUseChain::const_iterator E = tdu.begin();
00562 E != tdu.end(); ++E, ++e) {
00563 out << "[" << e << "] -> {";
00564 for (ValueTracking::DefUseChain::const_iterator J = E->begin(); J != E->end(); ++J) {
00565 if(J != E->begin()) out << ", ";
00566 out << *J;
00567 }
00568 out << "}\n";
00569 }
00570 return out;
00571 }
00572
00573 ShInfo* InputValueTracking::clone() const
00574 {
00575 return new InputValueTracking(*this);
00576 }
00577
00578
00579 std::ostream& operator<<(std::ostream& out, const InputValueTracking& ivt)
00580 {
00581 InputValueTracking::InputTupleDefUseChain::const_iterator I;
00582 for(I = ivt.inputUses.begin(); I != ivt.inputUses.end(); ++I) {
00583 ShVariableNodePtr node = I->first;
00584 const ValueTracking::TupleDefUseChain& tdu = I->second;
00585 out << node->name();
00586 out << tdu;
00587 }
00588 return out;
00589 }
00590
00591
00592 ShInfo* OutputValueTracking::clone() const
00593 {
00594 return new OutputValueTracking(*this);
00595 }
00596
00597 std::ostream& operator<<(std::ostream& out, const OutputValueTracking& ovt)
00598 {
00599 OutputValueTracking::OutputTupleUseDefChain::const_iterator O;
00600 for(O = ovt.outputDefs.begin(); O != ovt.outputDefs.end(); ++O) {
00601 ShVariableNodePtr node = O->first;
00602 const ValueTracking::TupleUseDefChain& tud = O->second;
00603 out << node->name() << tud;
00604 }
00605 return out;
00606 }
00607
00608 void add_value_tracking(ShProgram& p)
00609 {
00610 ReachingDefs r;
00611
00612 ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00613
00614 DefFinder finder(r, graph->entry(), p.node()->inputs);
00615 graph->dfs(finder);
00616
00617 InitRch init(r);
00618 graph->dfs(init);
00619
00620 bool changed;
00621 IterateRch iter(r, changed);
00622 do {
00623 changed = false;
00624 graph->dfs(iter);
00625 } while (changed);
00626
00627 #ifdef SH_DEBUG_VALUETRACK
00628 SH_DEBUG_PRINT("Dumping Reaching Defs");
00629 SH_DEBUG_PRINT("defsize = " << r.defsize);
00630 SH_DEBUG_PRINT("defs.size() = " << r.defs.size());
00631 for(unsigned int i = 0; i < r.defs.size(); ++i) {
00632 SH_DEBUG_PRINT(" " << i << ": " << r.defs[i]);
00633 }
00634 std::cerr << std::endl;
00635
00636 for (ReachingDefs::ReachingMap::const_iterator I = r.rchin.begin(); I != r.rchin.end(); ++I) {
00637 ShCtrlGraphNodePtr node = I->first;
00638 SH_DEBUG_PRINT(" rchin[" << node.object() << "]: " << I->second);
00639 SH_DEBUG_PRINT(" gen[" << node.object() << "]: " << r.gen[I->first]);
00640 SH_DEBUG_PRINT(" prsv[" << node.object() << "]: " << r.prsv[I->first]);
00641 std::cerr << std::endl;
00642 }
00643 #endif
00644
00645 UdDuClearer clearer;
00646 graph->dfs(clearer);
00647
00648 UdDuBuilder builder(r, p.node());
00649 graph->dfs(builder);
00650
00651 #ifdef SH_DEBUG_VALUETRACK
00652 SH_DEBUG_PRINT("Uddu Dump");
00653 UdDuDumper dumper;
00654 graph->dfs(dumper);
00655 dumper(p.node());
00656 #endif
00657 }
00658
00659
00660
00661 }