00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00020 #include <map>
00021 #include <set>
00022 #include <utility>
00023 #include <iostream>
00024 #include <sstream>
00025 #include <fstream>
00026 #include "ShOptimizations.hpp"
00027 #include "ShBitSet.hpp"
00028 #include "ShCtrlGraph.hpp"
00029 #include "ShDebug.hpp"
00030 #include "ShEvaluate.hpp"
00031 #include "ShContext.hpp"
00032 #include "ShSyntax.hpp"
00033 #include "ShProgramNode.hpp"
00034
00035
00036
00037
00038 #ifdef SH_DEBUG_OPTIMIZER
00039 #ifndef SH_DEBUG_VALUETRACK
00040 #define SH_DEBUG_VALUETRACK
00041 #endif
00042 #endif
00043 namespace {
00044 using namespace SH;
00045
00046
00047
00048 struct ReachingDefs {
00049 ReachingDefs()
00050 : defsize(0)
00051 {
00052 }
00053
00054 struct Definition {
00055 Definition(ShStatement* stmt,
00056 const ShCtrlGraphNodePtr& node,
00057 int offset,
00058 ShBitSet disable_mask)
00059 : varnode(stmt->dest.node()), stmt(stmt), node(node), offset(offset),
00060 disable_mask(disable_mask)
00061 {}
00062
00063 Definition(const ShVariableNodePtr& varnode,
00064 const ShCtrlGraphNodePtr& node,
00065 int offset,
00066 ShBitSet disable_mask)
00067 : varnode(varnode), stmt(0), node(node), offset(offset),
00068 disable_mask(disable_mask)
00069 {}
00070
00071 bool isInput() const
00072 {
00073 return !stmt;
00074 }
00075
00076 int size() const
00077 {
00078 if(isInput()) return varnode->size();
00079 else return stmt->dest.size();
00080 }
00081
00082
00083 int isDisabled(int i) const
00084 {
00085 return disable_mask[index(i)];
00086 }
00087
00088
00089 int off(int i) const
00090 {
00091 return offset + i;
00092 }
00093
00094
00095 int index(int i) const
00096 {
00097 if(isInput()) return i;
00098 return stmt->dest.swizzle()[i];
00099 }
00100
00101
00102 ValueTracking::Def toDef(int i) const
00103 {
00104 if(isInput()) return ValueTracking::Def(varnode, i);
00105 return ValueTracking::Def(stmt, i);
00106 }
00107
00108 friend std::ostream& operator<<(std::ostream& out, const Definition& def)
00109 {
00110 out << "{";
00111 out << " off " << def.offset;
00112 out << ", node " << def.node.object();
00113 out << ", dsbl " << def.disable_mask << ", ";
00114 if(def.isInput()) out << "input " << def.varnode->name();
00115 else out << "stmt " << *def.stmt;
00116 out << "}";
00117 return out;
00118 }
00119
00120 ShVariableNodePtr varnode;
00121 ShStatement* stmt;
00122 ShCtrlGraphNodePtr node;
00123 int offset;
00124
00125
00126
00127
00128
00129 ShBitSet disable_mask;
00130 };
00131
00132 void addDefinition(const Definition& d)
00133 {
00134
00135 defs.push_back(d);
00136 defsize += d.size();
00137 }
00138
00139 typedef std::map<ShCtrlGraphNodePtr, int> SizeMap;
00140 typedef std::map<ShCtrlGraphNodePtr, ShBitSet> ReachingMap;
00141
00142 std::vector<Definition> defs;
00143 int defsize;
00144
00145 ReachingMap gen, prsv;
00146 ReachingMap rchin;
00147 };
00148
00149 struct DefFinder {
00150 DefFinder(ReachingDefs& r, const ShCtrlGraphNodePtr& entry, const ShProgramNode::VarList& inputs)
00151 : entry(entry), inputs(inputs), r(r), offset(0)
00152 {
00153 }
00154
00155
00156 DefFinder& operator=(DefFinder const&);
00157
00158 void operator()(const ShCtrlGraphNodePtr& node)
00159 {
00160 if (!node) return;
00161 ShBasicBlockPtr block = node->block;
00162
00163
00164
00165 std::map<ShVariableNodePtr, ShBitSet> disable_map;
00166
00167 if(block) {
00168
00169 ShBasicBlock::ShStmtList::iterator I = block->end();
00170 while (1) {
00171 if (I == block->begin()) break;
00172 --I;
00173 if (!I->dest.null() && I->op != SH_OP_KIL && I->op != SH_OP_OPTBRA ) {
00174
00175
00176 if (disable_map.find(I->dest.node()) == disable_map.end()) {
00177 disable_map[I->dest.node()] = ShBitSet(I->dest.node()->size());
00178 }
00179
00180
00181
00182 ShBitSet defn_map(I->dest.node()->size());
00183 for (int i = 0; i < I->dest.size(); i++) defn_map[I->dest.swizzle()[i]] = true;
00184
00185
00186
00187 if ((defn_map & disable_map[I->dest.node()]) == defn_map) continue;
00188
00189
00190 r.addDefinition(ReachingDefs::Definition(&(*I), node, r.defsize,
00191 disable_map[I->dest.node()]));
00192
00193
00194 disable_map[I->dest.node()] |= defn_map;
00195 }
00196 }
00197 }
00198
00199
00200 if(node != entry) return;
00201 for(ShProgramNode::VarList::const_iterator I = inputs.begin();
00202 I != inputs.end(); ++I) {
00203 if(disable_map.find(*I) == disable_map.end()) {
00204 disable_map[*I] = ShBitSet((*I)->size());
00205 }
00206
00207
00208 if(disable_map[*I].full()) continue;
00209
00210 r.addDefinition(ReachingDefs::Definition(*I, node, r.defsize,
00211 disable_map[*I]));
00212 }
00213 }
00214
00215 ShCtrlGraphNodePtr entry;
00216 const ShProgramNode::VarList& inputs;
00217 ReachingDefs& r;
00218 int offset;
00219 };
00220
00221 struct InitRch {
00222 InitRch(ReachingDefs& r)
00223 : r(r)
00224 {
00225 }
00226
00227
00228 InitRch& operator=(InitRch const&);
00229
00230 void operator()(const ShCtrlGraphNodePtr& node)
00231 {
00232 if (!node) return;
00233
00234 r.rchin[node] = ShBitSet(r.defsize);
00235 r.gen[node] = ShBitSet(r.defsize);
00236 r.prsv[node] = ~ShBitSet(r.defsize);
00237
00238 ShBasicBlockPtr block = node->block;
00239
00240
00241 for (unsigned int i = 0; i < r.defs.size(); i++) {
00242 ReachingDefs::Definition &d = r.defs[i];
00243 if (d.node == node) {
00244 for (int j = 0; j < d.size(); j++) {
00245
00246
00247
00248 if (!d.isDisabled(j)) {
00249 r.gen[node][d.off(j)] = true;
00250 }
00251 }
00252 }
00253 }
00254
00255 if(!block) return;
00256
00257
00258 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00259 if (I->dest.null()
00260 || I->op == SH_OP_KIL
00261 || I->op == SH_OP_OPTBRA
00262 ) continue;
00263 for (unsigned int i = 0; i < r.defs.size(); i++) {
00264 ReachingDefs::Definition &d = r.defs[i];
00265 if (d.varnode != I->dest.node()) continue;
00266
00267 for (int j = 0; j < I->dest.size(); ++j) {
00268 for (int k = 0; k < d.size(); ++k) {
00269 if (d.index(k) == I->dest.swizzle()[j]) {
00270 r.prsv[node][d.off(k)] = false;
00271 }
00272 }
00273 }
00274 }
00275 }
00276 }
00277
00278 ReachingDefs& r;
00279 };
00280
00281 struct IterateRch {
00282 IterateRch(ReachingDefs& r, bool& changed)
00283 : r(r), changed(changed)
00284 {
00285 }
00286
00287
00288 IterateRch& operator=(IterateRch const&);
00289
00290 void operator()(const ShCtrlGraphNodePtr& node)
00291 {
00292 if (!node) return;
00293 SH_DEBUG_ASSERT(r.rchin.find(node) != r.rchin.end());
00294 ShBitSet newRchIn(r.defsize);
00295
00296 for (ShCtrlGraphNode::ShPredList::iterator I = node->predecessors.begin();
00297 I != node->predecessors.end(); ++I) {
00298 SH_DEBUG_ASSERT(r.gen.find(*I) != r.gen.end());
00299 SH_DEBUG_ASSERT(r.prsv.find(*I) != r.prsv.end());
00300 SH_DEBUG_ASSERT(r.rchin.find(*I) != r.rchin.end());
00301
00302 newRchIn |= (r.gen[*I] | (r.rchin[*I] & r.prsv[*I]));
00303 }
00304 if (newRchIn != r.rchin[node]) {
00305 r.rchin[node] = newRchIn;
00306 changed = true;
00307 }
00308 }
00309
00310 ReachingDefs& r;
00311 bool& changed;
00312 };
00313
00314
00315
00316 struct UdDuBuilder {
00317 UdDuBuilder(ReachingDefs& r, const ShProgramNodePtr& p)
00318 : r(r), m_exit(p->ctrlGraph->exit()), p(p),
00319 intrack(new InputValueTracking()),
00320 outtrack(new OutputValueTracking())
00321 {
00322 p->destroy_info<InputValueTracking>();
00323 p->destroy_info<OutputValueTracking>();
00324 p->add_info(intrack);
00325 p->add_info(outtrack);
00326 }
00327
00328 struct TupleElement {
00329 TupleElement(const ShVariableNodePtr& node, int index)
00330 : node(node), index(index)
00331 {
00332 }
00333
00334 bool operator<(const TupleElement& other) const
00335 {
00336 if (node < other.node) return true;
00337 if (node == other.node) return index < other.index;
00338 return false;
00339 }
00340
00341 ShVariableNodePtr node;
00342 int index;
00343 };
00344
00345
00346 UdDuBuilder& operator=(UdDuBuilder const&);
00347
00348 void operator()(const ShCtrlGraphNodePtr& node) {
00349 typedef std::set<ValueTracking::Def> DefSet;
00350 typedef std::map<TupleElement, DefSet> DefMap;
00351
00352
00353
00354 DefMap defs;
00355
00356 if (!node) return;
00357 ShBasicBlockPtr block = node->block;
00358
00359
00360
00361
00362
00363 for (std::size_t i = 0; i < r.defs.size(); i++) {
00364 for (int j = 0; j < r.defs[i].size(); j++) {
00365 if (r.rchin[node][r.defs[i].offset + j]) {
00366 ValueTracking::Def def(r.defs[i].toDef(j));
00367 defs[TupleElement(r.defs[i].varnode,
00368 r.defs[i].index(j))].insert(def);
00369 }
00370 }
00371 }
00372
00373 if(block) {
00374
00375 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00376
00377
00378 for (int j = 0; j < opInfo[I->op].arity; j++) {
00379
00380 ValueTracking* vt = I->get_info<ValueTracking>();
00381 if (!vt) {
00382 vt = new ValueTracking(&(*I));
00383 I->add_info(vt);
00384 }
00385 ShVariableNodePtr srcNode = I->src[j].node();
00386 for (int i = 0; i < I->src[j].size(); i++) {
00387 const DefSet& ds = defs[TupleElement(srcNode, I->src[j].swizzle()[i])];
00388
00389 vt->defs[j][i] = ds;
00390 ValueTracking::Use srcUse(&(*I), j, i);
00391 for (DefSet::const_iterator J = ds.begin(); J != ds.end(); J++) {
00392 switch(J->kind) {
00393 case ValueTracking::Def::INPUT:
00394 {
00395 ValueTracking::TupleDefUseChain& inputDu =
00396 intrack->inputUses[srcNode];
00397 if(inputDu.empty()) {
00398 inputDu.resize(srcNode->size());
00399 }
00400 inputDu[J->index].insert(srcUse);
00401 break;
00402 }
00403 case ValueTracking::Def::STMT:
00404 {
00405 ValueTracking* ut = J->stmt->get_info<ValueTracking>();
00406 if (!ut) {
00407 ut = new ValueTracking(J->stmt);
00408 J->stmt->add_info(ut);
00409 }
00410 ut->uses[J->index].insert(srcUse);
00411 break;
00412 }
00413 }
00414 }
00415
00416 }
00417 }
00418
00419 for (int i = 0; i < I->dest.size(); ++i) {
00420 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].clear();
00421 ValueTracking::Def def(&(*I), i);
00422 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].insert(def);
00423 }
00424 }
00425 }
00426
00427
00428
00429
00430
00431 if(node == m_exit) {
00432 for(DefMap::iterator D = defs.begin(); D != defs.end(); ++D) {
00433 ShVariableNodePtr node = D->first.node;
00434 ShBindingType kind = node->kind();
00435 int index = D->first.index;
00436
00437 if(kind == SH_INOUT || kind == SH_OUTPUT) {
00438 DefSet& ds = D->second;
00439
00440 ValueTracking::TupleUseDefChain& outDef = outtrack->outputDefs[node];
00441 if(outDef.empty()) {
00442 outDef.resize(node->size());
00443 }
00444 outDef[index] = ds;
00445
00446 for(DefSet::iterator S = ds.begin(); S != ds.end(); ++S) {
00447 if(S->kind != ValueTracking::Def::STMT) continue;
00448 ValueTracking* vt = S->stmt->get_info<ValueTracking>();
00449 if (!vt) {
00450 vt = new ValueTracking(S->stmt);
00451 S->stmt->add_info(vt);
00452 }
00453 SH_DEBUG_ASSERT(vt);
00454 ValueTracking::Use use(node, S->stmt->dest.swizzle()[S->index]);
00455 vt->uses[S->index].insert(use);
00456 }
00457 }
00458 }
00459 }
00460 }
00461
00462 ReachingDefs& r;
00463 ShCtrlGraphNodePtr m_exit;
00464 ShProgramNodePtr p;
00465 InputValueTracking* intrack;
00466 OutputValueTracking* outtrack;
00467 };
00468
00469 struct UdDuClearer {
00470 void operator()(const ShCtrlGraphNodePtr& node) {
00471 if (!node) return;
00472 ShBasicBlockPtr block = node->block;
00473 if (!block) return;
00474
00475 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00476 I->destroy_info<ValueTracking>();
00477 }
00478 }
00479 };
00480
00481 struct UdDuDumper {
00482 void operator()(const ShCtrlGraphNodePtr& node) {
00483 if (!node) return;
00484 ShBasicBlockPtr block = node->block;
00485 if (!block) return;
00486
00487 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00488 ValueTracking* vt = I->get_info<ValueTracking>();
00489 if (!vt) {
00490 SH_DEBUG_PRINT(*I << " HAS NO VALUE TRACKING");
00491 continue;
00492 }
00493 SH_DEBUG_PRINT("Valuetracking for " << *I);
00494 for (int i = 0; i < opInfo[I->op].arity; i++) {
00495 SH_DEBUG_PRINT(" src ud" << i << "\n" << vt->defs[i]);
00496 }
00497 SH_DEBUG_PRINT(" dest du" << vt->uses);
00498 }
00499 }
00500
00501 void operator()(const ShProgramNodePtr& p) {
00502 #ifdef SH_DEBUG
00503 InputValueTracking* ivt = p->get_info<InputValueTracking>();
00504 SH_DEBUG_ASSERT(ivt);
00505 SH_DEBUG_PRINT("Input Valuetracking:\n" << *ivt);
00506
00507 OutputValueTracking* ovt = p->get_info<OutputValueTracking>();
00508 SH_DEBUG_ASSERT(ovt);
00509 SH_DEBUG_PRINT("Output Valuetracking:\n" << *ovt);
00510 #endif
00511 }
00512 };
00513
00514 }
00515
00516 namespace SH {
00517
00518 ValueTracking::ValueTracking(ShStatement* stmt)
00519 : uses(stmt->dest.node() ? stmt->dest.size() : 0),
00520 defs(stmt->src.size())
00521 {
00522 #ifdef SH_DEBUG_VALUETRACK
00523 SH_DEBUG_PRINT("Adding value tracking to " << *stmt);
00524 #endif
00525 for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00526 for (int j = 0; j < (stmt->src[i].node() ? stmt->src[i].size() : 0); j++) {
00527 defs[i].push_back(std::set<Def>());
00528 }
00529 }
00530 }
00531
00532 ShInfo* ValueTracking::clone() const
00533 {
00534 return new ValueTracking(*this);
00535 }
00536
00537 std::ostream& operator<<(std::ostream& out, const ValueTracking::Use& use) {
00538 if(use.kind == ValueTracking::Use::STMT) {
00539 out << "(" << *use.stmt << ").src" << use.source << "[" << use.index << "]";
00540 } else {
00541 out << "(OUTPUT " << use.node->name() << ")[" << use.index << "]";
00542 }
00543 return out;
00544 }
00545
00546 std::ostream& operator<<(std::ostream& out, const ValueTracking::TupleUseDefChain& tud)
00547 {
00548 int e = 0;
00549 for (ValueTracking::TupleUseDefChain::const_iterator E = tud.begin();
00550 E != tud.end(); ++E, ++e) {
00551 out << "[" << e << "] <- {";
00552 for (ValueTracking::UseDefChain::const_iterator J = E->begin(); J != E->end(); ++J) {
00553 if(J != E->begin()) out << ", ";
00554 out << *J;
00555 }
00556 out << "}\n";
00557 }
00558 return out;
00559 }
00560
00561 std::ostream& operator<<(std::ostream& out, const ValueTracking::Def& def) {
00562 if(def.kind == ValueTracking::Def::STMT) {
00563 out << "(" << *def.stmt << ").dst" << "[" << def.index << "]";
00564 } else {
00565 out << "(INPUT " << def.node->name() << ")[" << def.index << "]";
00566 }
00567 return out;
00568 }
00569
00570 std::ostream& operator<<(std::ostream& out, const ValueTracking::TupleDefUseChain& tdu)
00571 {
00572 int e = 0;
00573 for (ValueTracking::TupleDefUseChain::const_iterator E = tdu.begin();
00574 E != tdu.end(); ++E, ++e) {
00575 out << "[" << e << "] -> {";
00576 for (ValueTracking::DefUseChain::const_iterator J = E->begin(); J != E->end(); ++J) {
00577 if(J != E->begin()) out << ", ";
00578 out << *J;
00579 }
00580 out << "}\n";
00581 }
00582 return out;
00583 }
00584
00585 ShInfo* InputValueTracking::clone() const
00586 {
00587 return new InputValueTracking(*this);
00588 }
00589
00590
00591 std::ostream& operator<<(std::ostream& out, const InputValueTracking& ivt)
00592 {
00593 InputValueTracking::InputTupleDefUseChain::const_iterator I;
00594 for(I = ivt.inputUses.begin(); I != ivt.inputUses.end(); ++I) {
00595 ShVariableNodePtr node = I->first;
00596 const ValueTracking::TupleDefUseChain& tdu = I->second;
00597 out << node->name();
00598 out << tdu;
00599 }
00600 return out;
00601 }
00602
00603
00604 ShInfo* OutputValueTracking::clone() const
00605 {
00606 return new OutputValueTracking(*this);
00607 }
00608
00609 std::ostream& operator<<(std::ostream& out, const OutputValueTracking& ovt)
00610 {
00611 OutputValueTracking::OutputTupleUseDefChain::const_iterator O;
00612 for(O = ovt.outputDefs.begin(); O != ovt.outputDefs.end(); ++O) {
00613 ShVariableNodePtr node = O->first;
00614 const ValueTracking::TupleUseDefChain& tud = O->second;
00615 out << node->name() << tud;
00616 }
00617 return out;
00618 }
00619
00620 void add_value_tracking(ShProgram& p)
00621 {
00622 ReachingDefs r;
00623
00624 ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00625
00626 DefFinder finder(r, graph->entry(), p.node()->inputs);
00627 graph->dfs(finder);
00628
00629 InitRch init(r);
00630 graph->dfs(init);
00631
00632 bool changed;
00633 IterateRch iter(r, changed);
00634 do {
00635 changed = false;
00636 graph->dfs(iter);
00637 } while (changed);
00638
00639 #ifdef SH_DEBUG_VALUETRACK
00640 SH_DEBUG_PRINT("Dumping Reaching Defs");
00641 SH_DEBUG_PRINT("defsize = " << r.defsize);
00642 SH_DEBUG_PRINT("defs.size() = " << r.defs.size());
00643 for(unsigned int i = 0; i < r.defs.size(); ++i) {
00644 SH_DEBUG_PRINT(" " << i << ": " << r.defs[i]);
00645 }
00646 std::cerr << std::endl;
00647
00648 for (ReachingDefs::ReachingMap::const_iterator I = r.rchin.begin(); I != r.rchin.end(); ++I) {
00649 ShCtrlGraphNodePtr node = I->first;
00650 SH_DEBUG_PRINT(" rchin[" << node.object() << "]: " << I->second);
00651 SH_DEBUG_PRINT(" gen[" << node.object() << "]: " << r.gen[I->first]);
00652 SH_DEBUG_PRINT(" prsv[" << node.object() << "]: " << r.prsv[I->first]);
00653 std::cerr << std::endl;
00654 }
00655 #endif
00656
00657 UdDuClearer clearer;
00658 graph->dfs(clearer);
00659
00660 UdDuBuilder builder(r, p.node());
00661 graph->dfs(builder);
00662
00663 #ifdef SH_DEBUG_VALUETRACK
00664 SH_DEBUG_PRINT("Uddu Dump");
00665 UdDuDumper dumper;
00666 graph->dfs(dumper);
00667 dumper(p.node());
00668 #endif
00669 }
00670
00671
00672
00673 }