Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members | Related Pages

ShAlgebra.cpp

// Sh: A GPU metaprogramming language.
//
// Copyright (c) 2003 University of Waterloo Computer Graphics Laboratory
// Project administrator: Michael D. McCool
// Authors: Zheng Qin, Stefanus Du Toit, Kevin Moule, Tiberiu S. Popa,
//          Michael D. McCool
//
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must
// not claim that you wrote the original software. If you use this
// software in a product, an acknowledgment in the product documentation
// would be appreciated but is not required.
//
// 2. Altered source versions must be plainly marked as such, and must
// not be misrepresented as being the original software.
//
// 3. This notice may not be removed or altered from any source
// distribution.
00027 #include <map> 00028 #include <string> 00029 #include <sstream> 00030 #include <algorithm> 00031 #include "ShAlgebra.hpp" 00032 #include "ShCtrlGraph.hpp" 00033 #include "ShDebug.hpp" 00034 #include "ShError.hpp" 00035 #include "ShOptimizations.hpp" 00036 #include "ShInternals.hpp" 00037 #include "ShEnvironment.hpp" 00038 #include "ShContext.hpp" 00039 #include "ShManipulator.hpp" 00040 #include "ShFixedManipulator.hpp" 00041 00042 namespace SH { 00043 00044 ShProgram connect(ShProgram pa, ShProgram pb) 00045 { 00046 ShProgramNodePtr a = pa.node(); 00047 ShProgramNodePtr b = pb.node(); 00048 00049 if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" ); 00050 if( !a ) return b; 00051 if( !b ) return a; 00052 00053 int aosize = a->outputs.size(); 00054 int bisize = b->inputs.size(); 00055 std::string rtarget; 00056 00057 if (a->target().empty()) { 00058 rtarget = b->target(); // A doesn't have a target. Use b's. 00059 } else { 00060 if (b->target().empty() || a->target() == b->target()) { 00061 rtarget = a->target(); // A has a target, b doesn't 00062 } else { 00063 SH_DEBUG_WARN("Connecting two different targets. Using empty target for result."); 00064 rtarget = ""; // Connecting different targets. 
00065 } 00066 } 00067 00068 ShProgramNodePtr program = new ShProgramNode(rtarget); 00069 00070 ShCtrlGraphNodePtr heada, taila, headb, tailb; 00071 00072 a->ctrlGraph->copy(heada, taila); 00073 b->ctrlGraph->copy(headb, tailb); 00074 00075 taila->append(headb); 00076 00077 ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb); 00078 program->ctrlGraph = new_graph; 00079 00080 program->inputs = a->inputs; 00081 00082 // push back extra inputs from b if aosize < bisize 00083 if(aosize < bisize) { 00084 ShProgramNode::VarList::const_iterator II = b->inputs.begin(); 00085 for(int i = 0; i < aosize; ++i, ++II); 00086 for(; II != b->inputs.end(); ++II) { 00087 program->inputs.push_back(*II); 00088 } 00089 } 00090 program->outputs = b->outputs; 00091 00092 // push back extra outputs from a if aosize > bisize 00093 if(aosize > bisize) { 00094 ShProgramNode::VarList::const_iterator II = a->outputs.begin(); 00095 for(int i = 0; i < bisize; ++i, ++II); 00096 for(; II != a->outputs.end(); ++II) { 00097 program->outputs.push_back(*II); 00098 } 00099 } 00100 00101 ShVariableReplacer::VarMap varMap; 00102 00103 ShContext::current()->enter(program); 00104 00105 ShProgramNode::VarList::const_iterator I, J; 00106 00107 ShProgramNode::VarList InOutInputs; 00108 ShProgramNode::VarList InOutOutputs; 00109 00110 // replace outputs and inputs connected together by temps 00111 for (I = a->outputs.begin(), J = b->inputs.begin(); 00112 I != a->outputs.end() && J != b->inputs.end(); ++I, ++J) { 00113 if((*I)->size() != (*J)->size()) { 00114 std::ostringstream err; 00115 err << "Cannot smash variables " 00116 << (*I)->nameOfType() << " " << (*I)->name() << " and " 00117 << (*J)->nameOfType() << " " << (*J)->name() << " with different sizes" << std::endl; 00118 err << "while connecting outputs: "; 00119 ShProgramNode::print(err, a->outputs) << std::endl; 00120 err << "to inputs: "; 00121 ShProgramNode::print(err, b->inputs) << std::endl; 00122 ShContext::current()->exit(); 00123 
shError(ShAlgebraException(err.str())); 00124 return ShProgram(ShProgramNodePtr(0)); 00125 } 00126 ShVariableNodePtr n = new ShVariableNode(SH_TEMP, (*I)->size()); 00127 varMap[*I] = n; 00128 varMap[*J] = n; 00129 00130 if((*I)->kind() == SH_INOUT) InOutInputs.push_back((*I)); 00131 if((*J)->kind() == SH_INOUT) InOutOutputs.push_back((*J)); 00132 } 00133 00134 // Change connected InOut variables to either Input or Output only 00135 // (since they have been connected and turned into temps internally) 00136 ShCtrlGraphNodePtr graphEntry; 00137 for (I = InOutInputs.begin(); I != InOutInputs.end(); ++I) { 00138 if(!graphEntry) graphEntry = program->ctrlGraph->prependEntry(); 00139 ShVariableNodePtr newInput(new ShVariableNode(SH_INPUT, (*I)->size(), 00140 (*I)->specialType())); 00141 if ((*I)->has_name()) { 00142 newInput->name((*I)->name()); 00143 } 00144 std::replace(program->inputs.begin(), program->inputs.end(), 00145 (*I), newInput); 00146 program->inputs.pop_back(); 00147 00148 graphEntry->block->addStatement(ShStatement( 00149 ShVariable(varMap[*I]), SH_OP_ASN, ShVariable(newInput))); 00150 } 00151 00152 ShCtrlGraphNodePtr graphExit; 00153 for (I = InOutOutputs.begin(); I != InOutOutputs.end(); ++I) { 00154 if(!graphExit) graphExit = program->ctrlGraph->appendExit(); 00155 ShVariableNodePtr newOutput(new ShVariableNode(SH_OUTPUT, (*I)->size(), 00156 (*I)->specialType())); 00157 if ((*I)->has_name()) { 00158 newOutput->name((*I)->name()); 00159 } 00160 std::replace(program->outputs.begin(), program->outputs.end(), 00161 (*I), newOutput); 00162 program->outputs.pop_back(); 00163 00164 graphExit->block->addStatement(ShStatement( 00165 ShVariable(newOutput), SH_OP_ASN, ShVariable(varMap[*I]))); 00166 } 00167 00168 ShContext::current()->exit(); 00169 00170 ShVariableReplacer replacer(varMap); 00171 program->ctrlGraph->dfs(replacer); 00172 00173 optimize(program); 00174 00175 program->collectVariables(); 00176 return program; 00177 } 00178 00179 ShProgram 
combine(ShProgram pa, ShProgram pb) 00180 { 00181 ShProgramNodePtr a = pa.node(); 00182 ShProgramNodePtr b = pb.node(); 00183 00184 std::string rtarget; 00185 if( !a || !b ) SH_DEBUG_WARN( "Connecting with a null ShProgram" ); 00186 if (!a) return b; 00187 if (!b) return a; 00188 00189 if (a->target().empty()) { 00190 rtarget = b->target(); // A doesn't have a target. Use b's. 00191 } else { 00192 if (b->target().empty() || a->target() == b->target()) { 00193 rtarget = a->target(); // A has a target, b doesn't 00194 } else { 00195 rtarget = ""; // Connecting different targets. 00196 } 00197 } 00198 00199 ShProgramNodePtr program = new ShProgramNode(rtarget); 00200 00201 ShCtrlGraphNodePtr heada, taila, headb, tailb; 00202 00203 a->ctrlGraph->copy(heada, taila); 00204 b->ctrlGraph->copy(headb, tailb); 00205 00206 taila->append(headb); 00207 00208 ShCtrlGraphPtr new_graph = new ShCtrlGraph(heada, tailb); 00209 program->ctrlGraph = new_graph; 00210 00211 program->inputs = a->inputs; 00212 program->inputs.insert(program->inputs.end(), b->inputs.begin(), b->inputs.end()); 00213 program->outputs = a->outputs; 00214 program->outputs.insert(program->outputs.end(), b->outputs.begin(), b->outputs.end()); 00215 00216 optimize(program); 00217 00218 program->collectVariables(); 00219 00220 return program; 00221 } 00222 00223 // Duplicates to inputs with matching name/type 00224 ShProgram mergeNames(ShProgram p) 00225 { 00226 typedef std::pair<std::string, int> InputType; 00227 typedef std::map< InputType, int > FirstOccurenceMap; // position of first occurence of an input type 00228 typedef std::vector< std::vector<int> > Duplicates; 00229 FirstOccurenceMap firsts; 00230 // dups[i] stores the set of positions that have matching input types with position i. 00231 // The whole set is stored in the smallest i position. 
00232 Duplicates dups( p.node()->inputs.size(), std::vector<int>()); 00233 00234 std::size_t i = 0; 00235 for(ShProgramNode::VarList::const_iterator I = p.node()->inputs.begin(); 00236 I != p.node()->inputs.end(); ++I, ++i) { 00237 InputType it( (*I)->name(), (*I)->size() ); 00238 if( firsts.find( it ) != firsts.end() ) { // duplicate 00239 dups[ firsts[it] ].push_back(i); 00240 } else { 00241 firsts[it] = i; 00242 dups[i].push_back(i); 00243 } 00244 } 00245 std::vector<int> swizzle; 00246 ShFixedManipulator duplicator; 00247 for(i = 0; i < dups.size(); ++i) { 00248 if( dups[i].empty() ) continue; 00249 for(std::size_t j = 0; j < dups[i].size(); ++j) swizzle.push_back(dups[i][j]); 00250 if( duplicator ) duplicator = duplicator & shDup(dups[i].size()); 00251 else duplicator = shDup(dups[i].size()); 00252 } 00253 ShProgram result = p << shSwizzle(swizzle); 00254 if( duplicator ) result = result << duplicator; 00255 return result.node(); 00256 } 00257 00258 ShProgram namedCombine(ShProgram a, ShProgram b) { 00259 return mergeNames(combine(a, b)); 00260 } 00261 00262 ShProgram namedConnect(ShProgram pa, ShProgram pb, bool keepExtra) 00263 { 00264 ShProgramNodeCPtr a = pa.node(); 00265 ShProgramNodeCPtr b = pb.node(); 00266 // positions of a pair of matched a output and b input 00267 typedef std::map<int, int> MatchedChannelMap; 00268 00269 std::vector<bool> aMatch(a->outputs.size(), false); 00270 std::vector<bool> bMatch(b->inputs.size(), false); 00271 MatchedChannelMap mcm; 00272 std::size_t i, j; 00273 ShProgramNode::VarList::const_iterator I, J; 00274 00275 i = 0; 00276 for(I = a->outputs.begin(); I != a->outputs.end(); ++I, ++i) { 00277 j = 0; 00278 for(J = b->inputs.begin(); J != b->inputs.end(); ++J, ++j) { 00279 if(bMatch[j]) continue; 00280 if((*I)->name() != (*J)->name()) continue; 00281 if((*I)->size() != (*J)->size()) { 00282 SH_DEBUG_WARN("Named connect matched channel name " << (*I)->name() 00283 << " but output size " << (*I)->size() << " != " << " input 
size " << (*J)->size() ); 00284 continue; 00285 } 00286 mcm[i] = j; 00287 aMatch[i] = true; 00288 bMatch[j] = true; 00289 } 00290 } 00291 00292 std::vector<int> swiz(b->inputs.size(), 0); 00293 for(MatchedChannelMap::iterator mcmit = mcm.begin(); mcmit != mcm.end(); ++mcmit) { 00294 swiz[mcmit->second] = mcmit->first; 00295 } 00296 00297 // swizzle unmatched inputs and make a pass them through properly 00298 ShProgram passer = SH_BEGIN_PROGRAM() {} SH_END; 00299 int newInputIdx = a->outputs.size(); // index of next new input added to a 00300 for(j = 0, J= b->inputs.begin(); J != b->inputs.end(); ++J, ++j) { 00301 if( !bMatch[j] ) { 00302 ShProgram passOne = SH_BEGIN_PROGRAM() { 00303 ShVariable var(new ShVariableNode(SH_INOUT, (*J)->size(), 00304 (*J)->specialType())); 00305 var.name((*J)->name()); 00306 } SH_END; 00307 passer = passer & passOne; 00308 swiz[j] = newInputIdx++; 00309 } 00310 } 00311 ShProgram aPass = combine(pa, passer); 00312 00313 if( keepExtra ) { 00314 for(i = 0; i < aMatch.size(); ++i) { 00315 if( !aMatch[i] ) swiz.push_back(i); 00316 } 00317 } 00318 00319 return mergeNames(pb << ( shSwizzle(swiz) << aPass )); 00320 } 00321 00322 ShProgram renameInput(ShProgram a, 00323 const std::string& oldName, const std::string& newName) { 00324 ShProgram renamer = SH_BEGIN_PROGRAM() { 00325 for(ShProgramNode::VarList::const_iterator I = a.node()->inputs.begin(); 00326 I != a.node()->inputs.end(); ++I) { 00327 ShVariable var(new ShVariableNode(SH_INOUT, (*I)->size(), 00328 (*I)->specialType())); 00329 00330 if (!(*I)->has_name()) continue; 00331 std::string name = (*I)->name(); 00332 if( name == oldName ) { 00333 var.name(newName); 00334 } else { 00335 var.name(name); 00336 } 00337 } 00338 } SH_END; 00339 return connect(renamer, a); 00340 } 00341 00342 // TODO factor out common code from renameInput, renameOutput 00343 ShProgram renameOutput(ShProgram a, 00344 const std::string& oldName, const std::string& newName) { 00345 ShProgram renamer = 
SH_BEGIN_PROGRAM() { 00346 for(ShProgramNode::VarList::const_iterator I = a.node()->outputs.begin(); 00347 I != a.node()->outputs.end(); ++I) { 00348 ShVariable var(new ShVariableNode(SH_INOUT, (*I)->size(), 00349 (*I)->specialType())); 00350 00351 if (!(*I)->has_name()) continue; 00352 std::string name = (*I)->name(); 00353 if( name == oldName ) { 00354 var.name(newName); 00355 } else { 00356 var.name(name); 00357 } 00358 } 00359 } SH_END; 00360 return connect(a, renamer); 00361 } 00362 00363 ShProgram namedAlign(ShProgram a, ShProgram b) { 00364 ShManipulator<std::string> ordering; 00365 00366 for(ShProgramNode::VarList::const_iterator I = b.node()->inputs.begin(); 00367 I != b.node()->inputs.end(); ++I) { 00368 ordering((*I)->name()); 00369 } 00370 00371 return ordering << a; 00372 } 00373 00374 ShProgram operator<<(ShProgram a, ShProgram b) 00375 { 00376 return connect(b,a); 00377 } 00378 00379 ShProgram operator>>(ShProgram a, ShProgram b) 00380 { 00381 return connect(a,b); 00382 } 00383 00384 ShProgram operator&(ShProgram a, ShProgram b) 00385 { 00386 return combine(a, b); 00387 } 00388 00389 ShProgram operator>>(ShProgram p, const ShVariable &var) { 00390 return replaceUniform(p, var); 00391 } 00392 00393 ShProgram replaceUniform(ShProgram a, const ShVariable& v) 00394 { 00395 if(!v.uniform()) { 00396 shError(ShAlgebraException("Cannot replace non-uniform variable")); 00397 } 00398 00399 ShProgram program(a.node()->clone()); 00400 00401 ShVariableReplacer::VarMap varMap; 00402 00403 ShContext::current()->enter(program.node()); 00404 00405 // make a new input 00406 ShVariableNodePtr newInput = new ShVariableNode(SH_INPUT, v.size(), v.node()->specialType()); 00407 varMap[v.node()] = newInput; 00408 00409 ShContext::current()->exit(); 00410 00411 ShVariableReplacer replacer(varMap); 00412 program.node()->ctrlGraph->dfs(replacer); 00413 00414 optimize(program); 00415 00416 program.node()->collectVariables(); 00417 00418 return program; 00419 } 00420 00421 }

Generated on Mon Oct 18 14:17:38 2004 for Sh by doxygen 1.3.7