Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

ShConstProp.cpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright 2003-2005 Serious Hack Inc.
00004 // 
00005 // This library is free software; you can redistribute it and/or
00006 // modify it under the terms of the GNU Lesser General Public
00007 // License as published by the Free Software Foundation; either
00008 // version 2.1 of the License, or (at your option) any later version.
00009 //
00010 // This library is distributed in the hope that it will be useful,
00011 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00013 // Lesser General Public License for more details.
00014 //
00015 // You should have received a copy of the GNU Lesser General Public
00016 // License along with this library; if not, write to the Free Software
00017 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, 
00018 // MA  02110-1301, USA
00020 #include "ShOptimizations.hpp"
00021 #include <map>
00022 #include <set>
00023 #include <utility>
00024 #include <iostream>
00025 #include "ShBitSet.hpp"
00026 #include "ShCtrlGraph.hpp"
00027 #include "ShDebug.hpp"
00028 #include "ShVariant.hpp"
00029 #include "ShEvaluate.hpp"
00030 #include "ShContext.hpp"
00031 #include "ShSyntax.hpp"
00032 #include "ShInfo.hpp"
00033 #include <sstream>
00034 #include <fstream>
00035 
00036 // Uncomment to enable constant/uniform propagation debugging (verbose!)
00037 // #define SH_DEBUG_CONSTPROP
00038 
00039 #ifdef SH_DEBUG_OPTIMIZER
00040 #ifndef SH_DEBUG_CONSTPROP
00041 #define SH_DEBUG_CONSTPROP
00042 #endif
00043 #endif
00044 
00045 namespace {
00046 
00047 using namespace SH;
00048 
00049 typedef std::queue<ValueTracking::Def> ConstWorkList;
      // ^ FIFO worklist of ValueTracking::Def entries. ConstProp::setDest()
      //   pushes Def(stmt, index) whenever a destination cell changes.
00050 
00051 struct ConstProp : public ShInfo {
00052   ConstProp(ShStatement* stmt,
00053             ShProgramNodeCPtr prog,
00054             ConstWorkList& worklist)
00055     : stmt(stmt)
00056   {
          // Seed one lattice cell per source element, classified by the kind
          // of variable node the operand refers to, then derive the initial
          // destination cells. `prog` is accepted but not used here.
          const int arity = opInfo[stmt->op].arity;
          for (int s = 0; s < arity; ++s) {
            std::vector<Cell>& cells = src[s];
            for (int e = 0; e < stmt->src[s].size(); ++e) {
              switch (stmt->src[s].node()->kind()) {
              case SH_CONST:
                // Literal operand: its value is known right now.
                cells.push_back(Cell(Cell::CONSTANT, stmt->src[s].getVariant(e)));
                break;
              case SH_TEMP:
                if (!stmt->src[s].uniform()) {
                  // Ordinary temporary: nothing known yet.
                  cells.push_back(Cell(Cell::TOP));
                } else if (stmt->src[s].meta("opt:lifting") == "never") {
                  // Don't lift computations dependent on uniforms which
                  // have been marked with "opt:lifting" == "never"
                  cells.push_back(Cell(Cell::BOTTOM));
                } else {
                  cells.push_back(Cell(Cell::UNIFORM, stmt->src[s], e));
                }
                break;
              case SH_INPUT:
              case SH_OUTPUT:
              case SH_INOUT:
              case SH_TEXTURE:
              case SH_STREAM:
              case SH_PALETTE:
                // These kinds are always treated as BOTTOM.
                cells.push_back(Cell(Cell::BOTTOM));
                break;
              default:
                // Unknown binding type: bail out without computing dest.
                SH_DEBUG_ASSERT(0 && "Invalid ShBindingType");
                return;
              }
            }
          }
          updateDest(worklist);
00091   }
00092 
00093   ShInfo* clone() const;
00094 
00095   int idx(int destindex, int source)
00096   {
          // A one-element source is shared by every destination element, so
          // it is always read at element 0; otherwise elements line up 1:1.
          if (stmt->src[source].size() == 1) return 0;
          return destindex;
00098   }
00099 
00100   void updateDest(ConstWorkList& worklist)
00101   {
          // Re-derives the destination cells of `stmt` from the current source
          // cells. Every destination element whose cell actually changes is
          // queued on `worklist` by setDest(). The rule applied depends on
          // opInfo[stmt->op].result_source:
          //   IGNORE   - destination cells are left untouched
          //   ASN      - destination element i copies src[0][i]
          //   EXTERNAL - every destination element becomes BOTTOM
          //   LINEAR   - evaluated element by element (see below)
          //   ALL      - evaluated on whole tuples (see below)
00102     if(dest.empty()) {
            // First call: start every destination element at TOP.
00103       dest.resize(stmt->dest.size(), Cell(Cell::TOP)); 
00104     }
00105 
00106     // Ignore KIL, optbra, etc.
00107     if (opInfo[stmt->op].result_source == ShOperationInfo::IGNORE) return;
00108 
00109     if (stmt->op == SH_OP_ASN) {
            // NOTE(review): src[0] is indexed with i directly rather than
            // through idx(i, 0); presumably an ASN source always has as many
            // elements as its destination - confirm.
00110       for (int i = 0; i < stmt->dest.size(); i++) {
00111         setDest(i, src[0][i], worklist); // assume src[0][i] cannot move up lattice
00112       }
00113     } else if (opInfo[stmt->op].result_source == ShOperationInfo::EXTERNAL) {
00114       // This statement never results in a constant
00115       // E.g. texture fetches, stream fetches.
00116       for (int i = 0; i < stmt->dest.size(); i++) {
00117         setDest(i, Cell(Cell::BOTTOM), worklist);
00118       }
00119     } else if (opInfo[stmt->op].result_source == ShOperationInfo::LINEAR) {
00120       // The strategy here is to ensure that 
00121       // a) whenever one src becomes bottom, dest becomes bottom
00122       // b) uniform only gets set when ALL src are uniform (because value
00123       // tracking requires it)
00124       // c) otherwise, propagate CONST state per element
00125       // @todo range (CONST may move across to UNIFORM, check that this is okay)
00126 
00127       // Consider each tuple element in turn.
00128       // Dest and sources are guaranteed to be of the same length.
00129       // Except that sources might be scalar.
00130       bool all_fields_uniform = true;
00131       bool some_field_bottom = false;
00132       for (int i = 0; i < stmt->dest.size(); i++) {
              // Per-element summary of the source cells feeding element i.
              // "alluniform" here means: every source is UNIFORM or CONSTANT.
00133         bool alluniform = true;
00134         bool allconst = true;
00135         bool somebottom = false;
00136         for (int s = 0; s < opInfo[stmt->op].arity; s++) {
00137           if (src[s][idx(i,s)].state == Cell::BOTTOM) {
00138             somebottom = true;
00139           }
00140           if (src[s][idx(i,s)].state != Cell::CONSTANT) {
00141             allconst = false;
00142             if (src[s][idx(i,s)].state != Cell::UNIFORM) {
00143               alluniform = false;
00144             }
00145           }
00146         }
              // some_field_bottom is accumulated but not read again in this
              // function.
00147         some_field_bottom |= somebottom;
              // An element that is entirely constant does not count as
              // uniform: it is folded to a CONSTANT cell below instead.
00148         if (!(alluniform && !allconst)) all_fields_uniform = false;
00149         if (allconst) {
                // Fold this one element: run a scalar copy of the statement
                // on one-element constant temporaries through evaluate().
00150           ShVariable tmpdest(new ShVariableNode(SH_CONST, 1, stmt->dest.valueType()));
00151           ShStatement eval(*stmt);
00152           eval.dest = tmpdest;
00153           for (int k = 0; k < opInfo[stmt->op].arity; k++) {
00154             ShVariantCPtr srcValue = src[k][idx(i,k)].value;
00155             ShVariable tmpsrc(new ShVariableNode(SH_CONST, 1, srcValue->valueType()));
00156             tmpsrc.setVariant(srcValue, 0);
00157             eval.src[k] = tmpsrc;
00158           }
00159           evaluate(eval);
00160           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(0)), worklist);
00161         } else if (somebottom) {
00162           setDest(i, Cell(Cell::BOTTOM), worklist);
00163         }
00164       } 
00165 
00166 
00167       // Because making a uniform cell based on a ConstProp requires
00168       // generating a value for the entire statement (not just one
00169       // field of the destination), we only push said cells if ALL of
00170       // the indices are uniform for ALL of their corresponding sources.
00171       if (all_fields_uniform) {
00172         for (int i = 0; i < stmt->dest.size(); i++) {
00173           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00174         }
00175       }
00176     } else if (opInfo[stmt->op].result_source == ShOperationInfo::ALL) {
00177       // build statement ONLY if ALL elements of ALL sources are constant
00178       bool allconst = true;
00179       bool alluniform = true; // all statements are either uniform or constant
00180       bool somebottom = false;
            // Scanning stops at the first source that contains a BOTTOM cell;
            // by then allconst and alluniform are already false.
00181       for (int s = 0; s < opInfo[stmt->op].arity && !somebottom; s++) {
00182         for (unsigned int k = 0; k < src[s].size(); k++) {
00183           if(src[s][k].state == Cell::BOTTOM) {
00184             somebottom = true;
00185           }
00186           if (src[s][k].state != Cell::CONSTANT) {
00187             allconst = false;
00188             if (src[s][k].state != Cell::UNIFORM) {
00189               alluniform = false;
00190             }
00191           }
00192         }
00193       }
00194       if (allconst) { 
              // Fold the whole statement on full-size constant temporaries.
00195         ShVariable tmpdest(new ShVariableNode(SH_CONST, stmt->dest.size(), stmt->dest.valueType()));
00196         ShStatement eval(*stmt);
00197         eval.dest = tmpdest;
00198         for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00199           SH_DEBUG_ASSERT(src[i][0].value); // @todo type DEBUGGING
00200           ShValueType srcValueType = src[i][0].value->valueType(); 
00201           ShVariable tmpsrc(new ShVariableNode(SH_CONST, stmt->src[i].size(), srcValueType));
00202           for (int j = 0; j < stmt->src[i].size(); j++) {
00203             tmpsrc.setVariant(src[i][j].value, j);
00204           }
00205           eval.src[i] = tmpsrc;
00206         }
00207         evaluate(eval);
00208         for (int i = 0; i < stmt->dest.size(); i++) {
00209           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(i)), worklist);
00210         }
00211       } else if (alluniform) {
00212         for (int i = 0; i < stmt->dest.size(); i++) {
00213           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00214         }
00215       } else if (somebottom) {
00216         for (int i = 0; i < stmt->dest.size(); i++) {
00217           setDest(i, Cell(Cell::BOTTOM), worklist);
00218         }
00219       }
00220     } else {
00221       SH_DEBUG_ASSERT(0 && "Invalid result source type");
00222     }
00223   }
00224 
00225   typedef int ValueNum;
00226 
00227   struct Uniform {
          // Describes one tuple element that is either a known constant
          // (constant == true, value in constval) or element `index` of the
          // value numbered `valuenum`, optionally negated (constant == false).
          //
          // Every constructor initialises every member: Uniform objects are
          // copied by value (each Cell embeds one), and copying an
          // uninitialised int/bool is undefined behaviour.

          // "No value" placeholder: not constant, valuenum == -1.
00228     Uniform()
00229       : constant(false),
              valuenum(-1),
              index(0),
              neg(false)
00231     {
00232     }
00233 
00234     // @todo type...this is my current understanding:
00235     // May be constant or if !constval, value is not
00236     // known to be constant
          // Stores cval->get(), or a null pointer when cval is null.
00237     Uniform(ShVariantCPtr cval)
00238       : constant(true),
              valuenum(-1),
              index(0),
              neg(false),
00239         constval(cval ? cval->get() : ShVariantPtr(0))
00240     {
00241     }
00242     
          // Element `index` of value number `valuenum`, negated when `neg`.
00243     Uniform(int valuenum, int index, bool neg)
00244       : constant(false),
00245         valuenum(valuenum), index(index), neg(neg)
00246     {
00247     }
00248 
          // Constants are equal when both hold a value and the values compare
          // equal; non-constants when value number, index and sign all match.
00249     bool operator==(const Uniform& other) const
00250     {
00251       if (constant != other.constant) return false;
00252 
00253       if (constant) {
              // A constant with no value never compares equal, not even to
              // another valueless constant (unchanged behaviour). The check
              // on other.constval keeps a null pointer out of equals().
              if (!constval || !other.constval) return false;
00258         return constval->equals(other.constval);
00259       } else {
00260         if (valuenum != other.valuenum) return false;
00261         if (index != other.index) return false;
00262         if (neg != other.neg) return false;
00263         return true;
00264       }
00265     }
00266 
          // Value type of the constant, or of the referenced Value otherwise.
          // NOTE(review): dereferences constval without a null check, and
          // passes valuenum to Value::get() even when it is -1 - callers
          // presumably only ask for the type of a populated Uniform; confirm.
00267     ShValueType valueType() const 
00268     {
00269       if(constant) return constval->valueType();
00270       return Value::get(valuenum)->valueType();
00271     }
00272 
00273     bool operator!=(const Uniform& other) const
00274     {
00275       return !(*this == other);
00276     }
00277 
00278 
00279     bool constant;
00280 
          // Meaningful only when !constant:
00281     ValueNum valuenum;
00282     int index;
00283     bool neg;
00284 
          // Meaningful only when constant (may still be null):
00285     ShVariantPtr constval;
00286   };
00287 
00288   class Value {
          // Entry in the static table of values that Uniform::valuenum indexes.
          // An entry is either an existing variable node (NODE) or the result
          // of a statement whose source cells are all UNIFORM or CONSTANT
          // (STMT). Entries are created only through the lookup() functions,
          // which return the index (ValueNum) of an equal existing entry or
          // append a new one.
00289   public:
00290     enum Type {
00291       NODE,
00292       STMT
00293     };
00294 
00295     Type type;
00296     
00297     // Only for type == NODE:
00298     ShVariableNodePtr node;
00299 
00300     // Only for type == STMT:
          // (the NODE constructor leaves op, destsize and destValueType
          // uninitialised)
00301     ShOperation op;
00302     int destsize;
00303     ShValueType destValueType;
00304     std::vector<Uniform> src[3];
00305 
          // Empties the table. The Value objects themselves are not deleted
          // here (see the FIXMEs in lookup()).
00306     static void clear()
00307     {
00308       m_values.clear();
00309     }
00310     
          // Returns the number of the NODE entry for `node`, appending a new
          // entry if a linear scan of the table finds none.
00311     static ValueNum lookup(const ShVariableNodePtr& node)
00312     {
00313       for (std::size_t i = 0; i < m_values.size(); i++) {
00314         if (m_values[i]->type == NODE && m_values[i]->node == node) return i;
00315       }
00316       m_values.push_back(new Value(node)); // FIXME: Never gets deleted
00317       return m_values.size() - 1;
00318     }
00319 
          // NODE entries compare by node; STMT entries by operation,
          // destination size and type, and element-wise equal sources for
          // the first opInfo[op].arity operands.
00320     bool operator==(const Value& other) const
00321     {
00322       if (type != other.type) return false;
00323       if (type == NODE) {
00324         return node == other.node;
00325       } else if (type == STMT) {
00326         if (op != other.op || destsize != other.destsize || destValueType != other.destValueType) return false;
00327         for (int i = 0; i < opInfo[op].arity; i++) {
00328           if (src[i].size() != other.src[i].size()) return false;
00329           for (std::size_t j = 0; j < src[i].size(); j++) {
00330             if (src[i][j] != other.src[i][j]) return false;
00331           }
00332         }
00333         return true;
00334       }
00335       return false;
00336     }
00337 
00338     bool operator!=(const Value& other) const
00339     {
00340       return !(*this == other);
00341     }
00342 
          // Returns the number of the STMT entry equal to the value computed
          // by cp's statement. A candidate entry is built first; it is deleted
          // again if an equal entry already exists, otherwise it is appended.
00343     static ValueNum lookup(ConstProp* cp)
00344     {
00345       Value* val = new Value(cp); // FIXME: Never gets deleted
00346       
00347       for (std::size_t i = 0; i < m_values.size(); i++) {
              // Cheap rejections before the full comparison.
00348         if (m_values[i]->type != STMT) continue;
00349         if (m_values[i]->op != cp->stmt->op) continue;
00350 
00351         if (*val == *m_values[i]) {
00352           delete val;
00353           return i;
00354         }
00355       }
00356       m_values.push_back(val);
00357       return m_values.size() - 1;
00358     }
00359 
          // Value type of the node (NODE) or of the statement result (STMT).
00360     ShValueType valueType() {
00361       if(type == NODE) {
00362         return node->valueType();
00363       } 
00364       return destValueType;
00365     }
00366 
          // Returns entry number n; n is not range-checked.
00367     static Value* get(ValueNum n)
00368     {
00369       return m_values[n];
00370     }
00371 
          // Writes a listing of the whole table to `out`.
00372     static void dump(std::ostream& out);
00373 
          // Name of the node for NODE entries, empty string for STMT entries.
00374     std::string name() const
00375     {
00376       if(type == NODE) return node->name();
00377       else return ""; 
00378     }
00379 
00380   private:
00381     Value(const ShVariableNodePtr& node)
00382       : type(NODE), node(node)
00383     {
00384     }
00385 
          // Captures cp's operation, destination shape and, for each operand,
          // one Uniform per source cell: the cell's own Uniform when it is
          // UNIFORM, otherwise a constant Uniform built from the cell's value.
00386     Value(ConstProp* cp)
00387       : type(STMT), node(0), op(cp->stmt->op), destsize(cp->stmt->dest.size()), destValueType(cp->stmt->dest.valueType())
00388     {
00389       for (int i = 0; i < opInfo[cp->stmt->op].arity; i++) {
00390         for (std::size_t j = 0; j < cp->src[i].size(); j++) {
00391           if (cp->src[i][j].state == Cell::UNIFORM) {
00392             src[i].push_back(cp->src[i][j].uniform);
00393           } else {
00394             SH_DEBUG_ASSERT(cp->src[i][j].state == Cell::CONSTANT); // @todo type should be fixed
00395             src[i].push_back(Uniform(cp->src[i][j].value));
00396           }
00397         }
00398       }
00399     }
00400     
          // The table itself; a ValueNum is an index into this vector.
00401     static std::vector<Value*> m_values;
00402   };
00403   
00404   struct Cell {
          // Lattice cell for one tuple element of a source or destination.
          // meet() treats TOP as the identity and BOTTOM as absorbing;
          // CONSTANT cells carry `value`, UNIFORM cells carry `uniform`.
00405     enum State {
00406       BOTTOM,
00407       CONSTANT,
00408       UNIFORM,
00409       TOP
00410     };
00411 
00412     // @todo comments added, but may not be correct
00413    
00414     // Construct a CONSTANT Cell with non-null value
00415     // or a TOP/BOTTOM
00416     Cell(State state, ShVariantPtr value = 0)
00417       : state(state)
00418     {
            // Store value->get() rather than the pointer passed in.
00419       if(value) this->value = value->get(); 
00420       SH_DEBUG_ASSERT(this->value || (state != CONSTANT));
00421     }
00422 
00423     // Construct a UNIFORM cell from a variable
          // (element `index` of var, seen through var's swizzle and negation)
00424     Cell(State state, const ShVariable& var, int index) 
00425       : state(state), value(0),
00426         uniform(Value::lookup(var.node()), var.swizzle()[index], var.neg())
00427     {
00428     }
00429 
00430     // Construct a UNIFORM cell as a result of some
00431     // statement where all relevant sources are uniforms/constants.
00432     Cell(State state, ConstProp* cp, int index) 
00433       : state(state), value(0),
00434         uniform(Value::lookup(cp), index, false)
00435     {
00436     }
00437 
          // Two cells are equal when they are in the same state and carry the
          // same payload for that state.
00438     bool operator==(const Cell& other) const
00439     {
00440       if(state != other.state) return false;
            // UNIFORM cells are distinguished by which value element they
            // name. Without this comparison two different uniforms looked
            // identical (both have a null `value`), so setDest() would keep
            // a stale uniform instead of recording the change.
            if (state == UNIFORM) return uniform == other.uniform;
00441       if(value) return value->equals(other.value);
            // No payload on this side: equal only if the other side has
            // none either.
00444       return value == other.value; // null
00445     }
00446 
00447     bool operator!=(const Cell& other) const
00448     {
00449       return !(*this == other);
00450     }
00451     
00452     State state;
00453     ShVariantPtr value; // Only for state == CONSTANT
00454 
00455     Uniform uniform; // Only for state == UNIFORM
00456   };
00457 
00458   // If dest[index] != cell, sets dest[index] to cell and updates the worklist 
00459   // Caller must ensure that cell does not move up the lattice. 
00460   void setDest(int index, const Cell &cell, ConstWorkList &worklist)
00461   {
          // Overwrite the slot and queue the definition only on a real change.
          Cell& slot = dest[index];
          if (!(slot == cell)) {
            slot = cell;
            worklist.push(ValueTracking::Def(stmt, index));
          }
00465   }
00466 
00467   ShStatement* stmt;
00468   std::vector<Cell> dest;
00469   std::vector<Cell> src[3];
00470 };
00471 
00472 ShInfo* ConstProp::clone() const
00473 {
        // Heap-allocated copy of this record, returned as an ShInfo*.
        ConstProp* duplicate = new ConstProp(*this);
        return duplicate;
00475 }
00476 
00477 ConstProp::Cell meet(const ConstProp::Cell& a, const ConstProp::Cell& b)
00478 {
        // Lattice meet of two cells: BOTTOM absorbs, TOP is the identity, and
        // two CONSTANT or two UNIFORM cells survive only if they agree.
        typedef ConstProp::Cell Cell;

        if (a.state == Cell::BOTTOM || b.state == Cell::BOTTOM) {
          return Cell(Cell::BOTTOM);
        }

        // TOP on one side: the other side wins.
        if (a.state == Cell::TOP) {
          if (b.state == Cell::TOP) return Cell(Cell::TOP);
          return b;
        }
        if (b.state == Cell::TOP) return a;

        // From here on both cells are CONSTANT or UNIFORM.
        if (a.state != b.state) return Cell(Cell::BOTTOM);

        if (a.state == Cell::CONSTANT) {
          SH_DEBUG_ASSERT(a.value); // @todo type debugging
          if (!a.value->equals(b.value)) return Cell(Cell::BOTTOM);
        } else if (a.state == Cell::UNIFORM) {
          if (a.uniform != b.uniform) return Cell(Cell::BOTTOM);
        }
        return a;
00507 }
00508 
00509 std::vector<ConstProp::Value*> ConstProp::Value::m_values = std::vector<ConstProp::Value*>();
      // ^ Definition of the static value table declared in ConstProp::Value;
      //   it starts out empty.
00510 
00511 std::ostream& operator<<(std::ostream& out, const ConstProp::Uniform& uniform)
00512 {
        // Prints a constant as its encoded value, anything else as
        // [-]v<valuenum>[<index>].
        if (!uniform.constant) {
          if (uniform.neg) out << '-';
          out << "v" << uniform.valuenum << "[" << uniform.index << "]";
          return out;
        }

        SH_DEBUG_ASSERT(uniform.constval); // @todo type DEBUGGING
        out << uniform.constval->encode();
        return out;
00521 }
00522 
00523 void ConstProp::Value::dump(std::ostream& out)
00524 {
        // One line per table entry: the node name for NODE entries; the
        // destination size, operation name and operands for STMT entries.
        out << "--- uniform values ---" << std::endl;
        for (std::size_t n = 0; n < m_values.size(); ++n) {
          Value* entry = m_values[n];
          out << n << ": ";
          switch (entry->type) {
          case NODE:
            out << "node " << entry->node->name() << std::endl;
            break;
          case STMT:
            out << "stmt [" << entry->destsize << "] " << opInfo[entry->op].name << " ";
            for (int operand = 0; operand < opInfo[entry->op].arity; ++operand) {
              if (operand != 0) out << ", ";
              const std::vector<Uniform>& elems = entry->src[operand];
              for (std::size_t e = 0; e < elems.size(); ++e) {
                out << elems[e];
              }
            }
            out << ";" << std::endl;
            break;
          }
        }
00542 }
00543 
00544 std::ostream& operator<<(std::ostream& out, const ConstProp::Cell& cell)
00545 {
        // [bot], [top], [<encoded constant>] or <uniform>, depending on state.
        typedef ConstProp::Cell Cell;
        const Cell::State state = cell.state;

        if (state == Cell::BOTTOM) {
          out << "[bot]";
        } else if (state == Cell::CONSTANT) {
          out << "[" << cell.value->encode() << "]";
        } else if (state == Cell::TOP) {
          out << "[top]";
        } else if (state == Cell::UNIFORM) {
          out << "<" << cell.uniform << ">";
        }
        return out;
00561 }
00562 
00563 struct InitConstProp {
        // Control-graph visitor that gives every statement of a block a fresh
        // ConstProp record; building the record seeds `worklist`.
00564   InitConstProp(ShProgramNodeCPtr prog, ConstWorkList& worklist)
00565     : prog(prog), worklist(worklist)
00566   {
00567   }
00568 
00569   void operator()(const ShCtrlGraphNodePtr& node)
00570   {
          // Nothing to do for a missing node or a node without a block.
          if (!node) return;
          ShBasicBlockPtr block = node->block;
          if (!block) return;

          typedef ShBasicBlock::ShStmtList::iterator StmtIter;
          for (StmtIter it = block->begin(); it != block->end(); ++it) {
            ShStatement& stmt = *it;
            // Drop any existing ConstProp record before attaching the new one.
            stmt.destroy_info<ConstProp>();
            stmt.add_info(new ConstProp(&stmt, prog, worklist));
          }
00579   }
00580 
00581   ShProgramNodeCPtr prog;
00582   ConstWorkList& worklist;
00583 };
00584 
00585 struct DumpConstProp {
        // Debug visitor: prints every statement of a block to std::cerr
        // together with the cells recorded for its destination and sources.
00586   void operator()(const ShCtrlGraphNodePtr& node)
00587   {
          if (!node) return;
          ShBasicBlockPtr block = node->block;
          if (!block) return;

          typedef ShBasicBlock::ShStmtList::iterator StmtIter;
          for (StmtIter it = block->begin(); it != block->end(); ++it) {
            std::cerr << "{" << *it << "} --- ";

            ConstProp* info = it->get_info<ConstProp>();
            if (!info) {
              std::cerr << "NO CP INFORMATION" << std::endl;
              continue;
            }

            // Destination cells first ...
            std::cerr << "dest = {";
            for (std::size_t d = 0; d < info->dest.size(); ++d) {
              std::cerr << info->dest[d];
            }
            std::cerr << "}; ";

            // ... then one comma-separated group per operand.
            for (int operand = 0; operand < opInfo[it->op].arity; ++operand) {
              if (operand != 0) std::cerr << ", ";
              const std::vector<ConstProp::Cell>& cells = info->src[operand];
              std::cerr << "src" << operand << " = {";
              for (std::size_t e = 0; e < cells.size(); ++e) {
                std::cerr << cells[e];
              }
              std::cerr << "}";
            }
            std::cerr << std::endl;
          }
00616   }
00617 };
00618 
00619 struct FinishConstProp
00620 {
00621   FinishConstProp(bool lift_uniforms)
00622     : lift_uniforms(lift_uniforms)
00623   {
          // lift_uniforms: when false, operator() stops after constant
          // replacement and never attempts uniform lifting.
00624   }
00625   
00626   void operator()(const ShCtrlGraphNodePtr& node) {
          // Rewrites the statements of node's block using the ConstProp
          // records attached to them:
          //  1. a statement whose destination cells are all CONSTANT becomes
          //     an ASN of a freshly built constant;
          //  2. otherwise each non-constant source whose cells are all
          //     CONSTANT is replaced by a constant;
          //  3. if lift_uniforms is set, sources whose cells all refer to the
          //     same uniform value (with the same sign) may be replaced by a
          //     variable obtained from build_uniform().
          // Finally every ConstProp record in the block is destroyed.
00627     if (!node) return;
00628     ShBasicBlockPtr block = node->block;
00629     if (!block) return;
00630     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00631       ConstProp* cp = I->get_info<ConstProp>();
00632 
00633       if (!cp) continue;
00634 
00635       if (!cp->dest.empty()) {
00636         // if all dest fields are constants, replace this with a
00637         // constant assignment
00638 
              // Statements that already are "ASN <constant>" are left alone.
00639         if (I->op != SH_OP_ASN || I->src[0].node()->kind() != SH_CONST) {
00640           bool allconst = true;
00641           for (int i = 0; i < I->dest.size(); i++) {
00642             if (cp->dest[i].state != ConstProp::Cell::CONSTANT) {
00643               allconst = false;
00644               break;
00645             }
00646           }
00647           if (allconst) {
00648             SH_DEBUG_ASSERT(cp->dest[0].value); // @todo type debugging
00649             ShValueType destValueType = cp->dest[0].value->valueType(); 
00650             ShVariable newconst(new ShVariableNode(SH_CONST, I->dest.size(), destValueType));
00651             for(int i = 0; i < I->dest.size(); ++i) {
00652               newconst.setVariant(cp->dest[i].value, i);
00653             }
00654 #ifdef SH_DEBUG_CONSTPROP
00655             SH_DEBUG_PRINT("Replaced {" << *I << "} with " << newconst);
00656 #endif
00657             *I = ShStatement(I->dest, SH_OP_ASN, newconst);
00658           } else {
00659             // otherwise, do the same for each source field.
00660             for (int s = 0; s < opInfo[I->op].arity; s++) {
00661               if (I->src[s].node()->kind() == SH_CONST) continue;
00662             
00663               ShValueType srcValueType = I->src[s].valueType();
00664               ShVariable newconst(new ShVariableNode(SH_CONST, I->src[s].size(), srcValueType));
                    // This allconst is local to the source loop and shadows
                    // the statement-level flag declared above.
00665               bool allconst = true;
00666               for (int i = 0; i < I->src[s].size(); i++) {
00667                 if (cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00668                   allconst = false;
00669                   break;
00670                 }
00671                 newconst.setVariant(cp->src[s][i].value, i);
00672               }
00673               if (allconst) {
00674 #ifdef SH_DEBUG_CONSTPROP
00675                 SH_DEBUG_PRINT("Replaced {" << *I << "}.src[" << s << "] with " << newconst);
00676 #endif
00677                 I->src[s] = newconst;
00678               }
00679             }
00680           }
00681 
                // Here allconst is the statement-level flag again: statements
                // that were folded to a constant are not considered for lifting.
00682           if (!lift_uniforms || allconst) {
00683             //SH_DEBUG_PRINT("Skipping uniform lifting");
00684             continue;
00685           }
00686 
                // alluniform: every source cell is UNIFORM or CONSTANT.
00687           bool alluniform = true;
00688           for (int s = 0; s < opInfo[I->op].arity; s++) {
00689             for (int i = 0; i < I->src[s].size(); i++) {
00690               if (cp->src[s][i].state != ConstProp::Cell::UNIFORM
00691                   && cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00692                 alluniform = false;
00693                 break;
00694               }
00695             }
00696             if (!alluniform) break;
00697           }
                // Per-source lifting is attempted when some source cell is
                // neither UNIFORM nor CONSTANT, or when the destination is an
                // OUTPUT/INOUT node.
00698           if (!alluniform || I->dest.node()->kind() == SH_OUTPUT
00699               || I->dest.node()->kind() == SH_INOUT) {
00700 #ifdef SH_DEBUG_CONSTPROP
00701             SH_DEBUG_PRINT("Considering " << *I << " for uniform lifting");
00702 #endif          
00703             for (int s = 0; s < opInfo[I->op].arity; s++) {
00704               if (I->src[s].uniform()) {
00705 #ifdef SH_DEBUG_CONSTPROP
00706                 SH_DEBUG_PRINT(*I << ".src[" << s << "] is already a uniform");
00707 #endif
00708                 continue;
00709               }
00710 
                    // Collect the value number, sign and element indices that
                    // the cells of this source refer to; `mixed` is set when
                    // they do not all name the same value with the same sign,
                    // or when some cell is not UNIFORM at all.
00711               bool mixed = false;
00712               bool neg = false;
00713               ConstProp::ValueNum uniform = -1;
00714               std::vector<int> indices;
00715               
00716               for (int i = 0; i < I->src[s].size(); i++) {
00717                 if (cp->src[s][i].state == ConstProp::Cell::UNIFORM) {
00718                   if (uniform < 0) {
00719                     uniform = cp->src[s][i].uniform.valuenum;
00720                     neg = cp->src[s][i].uniform.neg;
00721                     indices.push_back(cp->src[s][i].uniform.index);
00722                   } else {
00723                     if (uniform != cp->src[s][i].uniform.valuenum
00724                         || neg != cp->src[s][i].uniform.neg) {
00725                       mixed = true;
00726                       break;
00727                     }
00728                     indices.push_back(cp->src[s][i].uniform.index);
00729                   }
00730                 } else {
00731                   // Can't lift this, unless we introduce intermediate instructions.
00732                   mixed = true;
00733                 }
00734               }
00735               
00736               if (uniform < 0) {
00737 #ifdef SH_DEBUG_CONSTPROP
00738                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is not uniform");
00739 #endif
00740                 continue;
00741               }
00742               if (mixed) {
00743 #ifdef SH_DEBUG_CONSTPROP
00744                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is mixed");
00745 #endif
00746                 continue;
00747               }
00748               ConstProp::Value* value = ConstProp::Value::get(uniform);
00749 
00750 
00751 #ifdef SH_DEBUG_CONSTPROP
00752               SH_DEBUG_PRINT("Lifting {" << *I << "}.src[" << s << "]: " << uniform);
00753 #endif
                    // The swizzle selects the collected indices out of the
                    // full-size value. srcsize is read before build_uniform()
                    // may turn a STMT value into a NODE value.
00754               int srcsize;
00755               if (value->type == ConstProp::Value::NODE) {
00756                 srcsize = value->node->size();
00757               } else {
00758                 srcsize = value->destsize;
00759               }
                    // indices is non-empty here because uniform >= 0.
00760               ShSwizzle swizzle(srcsize, indices.size(), &(*indices.begin()));
00761               // Build a uniform to represent this computation.
00762               ShVariableNodePtr node = build_uniform(value, uniform);
00763               if (node) {
00764                 I->src[s] = ShVariable(node, swizzle, neg);
00765               } else {
00766 #ifdef SH_DEBUG_CONSTPROP
00767                 SH_DEBUG_PRINT("Could not lift " << *I << ".src[" << s << "] for some reason");
00768 #endif
00769               }
00770             }
00771           }
00772         }
00773       }
00774     }
00775 
00776     // Clean up
00777     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00778       // Remove constant propagation information.
00779       I->destroy_info<ConstProp>();
00780     }
00781   }
00782 
00783   ShVariableNodePtr build_uniform(ConstProp::Value* value,
00784                                   ConstProp::ValueNum valuenum)
00785   {
          // Returns a variable node that holds `value`.
          //  - NODE values simply return their node.
          //  - STMT values get a new SH_TEMP node named "dep_<valuenum>_<name>"
          //    plus a program named "uniform" containing the statement that
          //    recomputes the value (operands come from compute()). The
          //    program is attached to the node and `value` is converted to a
          //    NODE entry so that later requests return the same node.
          // Returns 0 when compute() could not produce one of the operands.
00786     if (value->type == ConstProp::Value::NODE) {
00787       return value->node;
00788     }
00789     
          // The node is created between enter(0) and exit() on the current
          // context.
00790     ShContext::current()->enter(0);
00791     ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00792     {
00793     std::ostringstream s;
          // value is a STMT entry here, so value->name() is the empty string.
00794     s << "dep_" << valuenum << "_" << value->name();
00795     node->name(s.str());
00796     }
00797     ShContext::current()->exit();
00798 
00799 #ifdef SH_DEBUG_CONSTPROP
00800     SH_DEBUG_PRINT("Lifting value #" << valuenum);
00801 #endif
00802 
00803     bool broken = false;
00804     
00805     ShProgram prg = SH_BEGIN_PROGRAM("uniform") {
00806       ShStatement stmt(node, value->op);
00807 
00808       for (int i = 0; i < opInfo[value->op].arity; i++) {
00809         stmt.src[i] = compute(value->src[i]);
00810         if (stmt.src[i].null()) broken = true;
00811       }
00812 
            // The statement is added even when an operand is null; in that
            // case the program is discarded below (never attached).
00813       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00814     } SH_END;
00815 
00816 #ifdef SH_DEBUG_CONSTPROP
00817     {
            // Debug only: write lifted_<valuenum>.dot and run the external
            // "dot" tool on it to produce a .ps file.
00818       std::ostringstream s;
00819       s << "lifted_" << valuenum;
00820       std::string dotfilename(s.str() + ".dot");
00821       std::ofstream dot(dotfilename.c_str());
00822       prg.node()->ctrlGraph->graphvizDump(dot);
00823       dot.close();
00824       std::string cmdline = std::string("dot -Tps -o ") + s.str() + ".ps " + s.str() + ".dot";
00825       system(cmdline.c_str());
00826     }
00827 #endif
00828     
00829     if (broken) return 0;
00830     node->attach(prg.node());
00831 
          // Remember the result: from now on this value is a plain node.
00832     value->type = ConstProp::Value::NODE;
00833     value->node = node;
00834     
00835     return node;
00836   }
00837 
00838   ShVariable compute(const std::vector<ConstProp::Uniform>& src)
00839   {
          // Builds a variable for the operand described element by element in
          // `src`, appending any statements needed to the block list of the
          // program currently being parsed.
          //  - all elements constant: a new SH_CONST variable;
          //  - all elements from the same value with the same sign: a swizzle
          //    of that value (recomputed first if it is a STMT value);
          //  - anything else: a new temporary filled one element at a time
          //    by recursive calls.
          // Returns a default-constructed ShVariable when a nested operand
          // cannot be computed.
          // NOTE(review): an empty `src` would reach Value::get(-1) below;
          // presumably operands always have at least one element - confirm.
00840     ConstProp::ValueNum v = -1;
00841     
00842     bool allsame = true;
00843     bool neg = false;
00844     std::vector<int> indices;
00845     std::vector<ShVariantCPtr> constvals;
00846     for (std::size_t i = 0; i < src.size(); i++) {
00847       if (src[i].constant) {
              // A constant after a non-constant element: mixed operand.
00848         if (v >= 0) {
00849           allsame = false;
00850           break;
00851         }
00852         constvals.push_back(src[i].constval);
00853       } else {
00854         if (v < 0 && constvals.empty()) {
                // First element seen: remember its value and sign.
00855           v = src[i].valuenum;
00856           neg = src[i].neg;
00857         } else {
00858           if (v != src[i].valuenum || neg != src[i].neg || !constvals.empty()) {
00859             allsame = false;
00860           }
00861         }
00862         indices.push_back(src[i].index);
00863       }
00864     }
00865     if (!allsame) {
00866       // Make intermediate variables, combine them together.
00867       ShVariable r = ShVariable(new ShVariableNode(SH_TEMP, src.size(), src[0].valueType()));
00868       
00869       for (std::size_t i = 0; i < src.size(); i++) {
              // This local vector `v` shadows the value number declared above.
00870         std::vector<ConstProp::Uniform> v;
00871         v.push_back(src[i]);
00872         ShVariable scalar = compute(v);
00873         ShStatement stmt(r(i), SH_OP_ASN, scalar);
00874         ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00875       }
00876       return r;
00877     }
00878 
00879     if (!constvals.empty()) {
00880       ShVariable var(new ShVariableNode(SH_CONST, constvals.size(), constvals[0]->valueType()));
00881       for(std::size_t i = 0; i < constvals.size(); ++i) var.setVariant(constvals[i], i);
00882       return var;
00883     }
00884     
00885     ConstProp::Value* value = ConstProp::Value::get(v);
00886     
00887     if (value->type == ConstProp::Value::NODE) {
00888       ShSwizzle swizzle(value->node->size(), indices.size(), &*indices.begin());
00889       return ShVariable(value->node, swizzle, neg);
00890     }
00891     if (value->type == ConstProp::Value::STMT) {
            // Recompute the statement into a fresh temporary, operands first.
00892       ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00893       ShStatement stmt(node, value->op);
00894 
00895       for (int i = 0; i < opInfo[value->op].arity; i++) {
00896         stmt.src[i] = compute(value->src[i]);
00897         if (stmt.src[i].null()) return ShVariable();
00898       }
00899 
00900       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00901       ShSwizzle swizzle(node->size(), indices.size(), &*indices.begin());
00902       return ShVariable(node, swizzle, neg);
00903     }
00904 
00905 #ifdef SH_DEBUG_CONSTPROP
00906     SH_DEBUG_PRINT("Reached invalid point");
00907 #endif
00908 
00909     // Should never reach here.
00910     return ShVariable();
00911   }
00912 
00913   bool lift_uniforms;
00914 };
00915 
00916 }
00917 
00918 namespace SH {
00919 
00920 
00921 void propagate_constants(ShProgram& p)
00922 {
00923   ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00924 
00925   ConstWorkList worklist;
00926 
00927   //ConstProp::Value::clear();
00928   
00929   InitConstProp init(p.node(), worklist);
00930   graph->dfs(init);
00931 
00932 #ifdef SH_DEBUG_CONSTPROP
00933   SH_DEBUG_PRINT("Const Prop Initial Values:");
00934   DumpConstProp dump_pre;
00935   graph->dfs(dump_pre);
00936 #endif
00937 
00938   
00939   while (!worklist.empty()) {
00940     ValueTracking::Def def = worklist.front(); worklist.pop();
00941     ValueTracking* vt = def.stmt->get_info<ValueTracking>();
00942     if (!vt) {
00943 #ifdef SH_DEBUG_CONSTPROP
00944       SH_DEBUG_PRINT(*def.stmt << " on worklist does not have VT information?");
00945 #endif
00946       continue;
00947     }
00948 
00949 
00950     for (ValueTracking::DefUseChain::iterator use = vt->uses[def.index].begin();
00951          use != vt->uses[def.index].end(); ++use) {
00952       if (use->kind != ValueTracking::Use::STMT) continue;
00953       ConstProp* cp = use->stmt->get_info<ConstProp>();
00954       if (!cp) {
00955 #ifdef SH_DEBUG_CONSTPROP
00956         SH_DEBUG_PRINT("Use " << *use->stmt << " does not have const prop information!");
00957 #endif
00958         continue;
00959       }
00960 
00961       ConstProp::Cell cell = cp->src[use->source][use->index];
00962 
00963       ValueTracking* ut = use->stmt->get_info<ValueTracking>();
00964       if (!ut) {
00965         // Should never happen...
00966 #ifdef SH_DEBUG_CONSTPROP
00967         SH_DEBUG_PRINT("Use " << *use->stmt << " on worklist does not have VT information?");
00968 #endif
00969         continue;
00970       }
00971 
00972 #ifdef SH_DEBUG_CONSTPROP
00973       SH_DEBUG_PRINT("Meeting cell for {" << *use->stmt
00974                      << "}.src" << use->source << "[" << use->index << "]");
00975 #endif
00976       
00977       ConstProp::Cell new_cell(ConstProp::Cell::TOP);
00978       
00979       for (ValueTracking::UseDefChain::iterator possdef
00980              = ut->defs[use->source][use->index].begin();
00981            possdef != ut->defs[use->source][use->index].end(); ++possdef) {
00982         ConstProp* dcp = possdef->stmt->get_info<ConstProp>();
00983         if (!dcp) {
00984 #ifdef SH_DEBUG_CONSTPROP
00985           SH_DEBUG_PRINT("Possible def " << *dcp->stmt << " on worklist does not have CP information?");
00986 #endif
00987           continue;
00988         }
00989         
00990         ConstProp::Cell destcell = dcp->dest[possdef->index];
00991 
00992         // If the use is negated, we need to change the cell
00993         if (use->stmt->src[use->source].neg()) {
00994           if (destcell.state == ConstProp::Cell::CONSTANT) {
00995             SH_DEBUG_ASSERT(destcell.value); // @todo type DEBUGGING
00996             destcell.value->negate();
00997           } else if (destcell.state == ConstProp::Cell::UNIFORM) {
00998             destcell.uniform.neg = !destcell.uniform.neg;
00999           }
01000         }
01001 #ifdef SH_DEBUG_CONSTPROP
01002         SH_DEBUG_PRINT("  meet(" << new_cell << ", " << destcell << ") = " <<
01003                        meet(new_cell, destcell));
01004 #endif
01005         new_cell = meet(new_cell, destcell);
01006       }
01007       
01008       if (cell != new_cell) {
01009 #ifdef SH_DEBUG_CONSTPROP
01010         SH_DEBUG_PRINT("  ...replacing cell");
01011 #endif
01012         cp->src[use->source][use->index] = new_cell;
01013         cp->updateDest(worklist);
01014       }
01015     }
01016   }
01017   // Now do something with our glorious newfound information.
01018 
01019   
01020 #ifdef SH_DEBUG_CONSTPROP
01021   ConstProp::Value::dump(std::cerr);
01022 
01023   DumpConstProp dump;
01024   graph->dfs(dump);
01025 #endif
01026   
01027   FinishConstProp finish(p.node()->target().find("gpu:") == 0
01028                          && !ShContext::current()->optimization_disabled("uniform lifting"));
01029   graph->dfs(finish);
01030   
01031 }
01032 
01033 }

Generated on Thu Jul 28 17:33:02 2005 for Sh by  doxygen 1.4.3-20050530