00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00027
#include <map>
00028
#include <string>
00029
#include <sstream>
00030
#include <algorithm>
00031
#include "ShAlgebra.hpp"
00032
#include "ShCtrlGraph.hpp"
00033
#include "ShDebug.hpp"
00034
#include "ShError.hpp"
00035
#include "ShOptimizations.hpp"
00036
#include "ShInternals.hpp"
00037
#include "ShEnvironment.hpp"
00038
#include "ShContext.hpp"
00039
#include "ShManipulator.hpp"
00040
#include "ShFixedManipulator.hpp"
00041
00042
namespace SH {
00043
00044 ShProgram connect(
ShProgram pa,
ShProgram pb)
00045 {
00046
ShProgramNodePtr a = pa.
node();
00047
ShProgramNodePtr b = pb.
node();
00048
00049
if( !a || !b ) SH_DEBUG_WARN(
"Connecting with a null ShProgram" );
00050
if( !a )
return b;
00051
if( !b )
return a;
00052
00053
int aosize = a->outputs.size();
00054
int bisize = b->inputs.size();
00055 std::string rtarget;
00056
00057
if (a->target().empty()) {
00058 rtarget = b->target();
00059 }
else {
00060
if (b->target().empty() || a->target() == b->target()) {
00061 rtarget = a->target();
00062 }
else {
00063 SH_DEBUG_WARN(
"Connecting two different targets. Using empty target for result.");
00064 rtarget =
"";
00065 }
00066 }
00067
00068
ShProgramNodePtr program =
new ShProgramNode(rtarget);
00069
00070
ShCtrlGraphNodePtr heada, taila, headb, tailb;
00071
00072 a->ctrlGraph->copy(heada, taila);
00073 b->ctrlGraph->copy(headb, tailb);
00074
00075 taila->append(headb);
00076
00077
ShCtrlGraphPtr new_graph =
new ShCtrlGraph(heada, tailb);
00078 program->ctrlGraph = new_graph;
00079
00080 program->inputs = a->inputs;
00081
00082
00083
if(aosize < bisize) {
00084 ShProgramNode::VarList::const_iterator II = b->inputs.begin();
00085
for(
int i = 0; i < aosize; ++i, ++II);
00086
for(; II != b->inputs.end(); ++II) {
00087 program->inputs.push_back(*II);
00088 }
00089 }
00090 program->outputs = b->outputs;
00091
00092
00093
if(aosize > bisize) {
00094 ShProgramNode::VarList::const_iterator II = a->outputs.begin();
00095
for(
int i = 0; i < bisize; ++i, ++II);
00096
for(; II != a->outputs.end(); ++II) {
00097 program->outputs.push_back(*II);
00098 }
00099 }
00100
00101 ShVariableReplacer::VarMap varMap;
00102
00103 ShContext::current()->enter(program);
00104
00105 ShProgramNode::VarList::const_iterator I, J;
00106
00107 ShProgramNode::VarList InOutInputs;
00108 ShProgramNode::VarList InOutOutputs;
00109
00110
00111
for (I = a->outputs.begin(), J = b->inputs.begin();
00112 I != a->outputs.end() && J != b->inputs.end(); ++I, ++J) {
00113
if((*I)->size() != (*J)->size()) {
00114 std::ostringstream err;
00115 err <<
"Cannot smash variables "
00116 << (*I)->nameOfType() <<
" " << (*I)->name() <<
" and "
00117 << (*J)->nameOfType() <<
" " << (*J)->name() <<
" with different sizes" << std::endl;
00118 err <<
"while connecting outputs: ";
00119 ShProgramNode::print(err, a->outputs) << std::endl;
00120 err <<
"to inputs: ";
00121 ShProgramNode::print(err, b->inputs) << std::endl;
00122 ShContext::current()->exit();
00123
shError(
ShAlgebraException(err.str()));
00124
return ShProgram(
ShProgramNodePtr(0));
00125 }
00126
ShVariableNodePtr n =
new ShVariableNode(SH_TEMP, (*I)->size());
00127 varMap[*I] = n;
00128 varMap[*J] = n;
00129
00130
if((*I)->kind() == SH_INOUT) InOutInputs.push_back((*I));
00131
if((*J)->kind() == SH_INOUT) InOutOutputs.push_back((*J));
00132 }
00133
00134
00135
00136
ShCtrlGraphNodePtr graphEntry;
00137
for (I = InOutInputs.begin(); I != InOutInputs.end(); ++I) {
00138
if(!graphEntry) graphEntry = program->ctrlGraph->prependEntry();
00139
ShVariableNodePtr newInput(
new ShVariableNode(SH_INPUT, (*I)->size(),
00140 (*I)->specialType()));
00141
if ((*I)->has_name()) {
00142 newInput->name((*I)->name());
00143 }
00144 std::replace(program->inputs.begin(), program->inputs.end(),
00145 (*I), newInput);
00146 program->inputs.pop_back();
00147
00148 graphEntry->block->addStatement(
ShStatement(
00149
ShVariable(varMap[*I]),
SH_OP_ASN,
ShVariable(newInput)));
00150 }
00151
00152
ShCtrlGraphNodePtr graphExit;
00153
for (I = InOutOutputs.begin(); I != InOutOutputs.end(); ++I) {
00154
if(!graphExit) graphExit = program->ctrlGraph->appendExit();
00155
ShVariableNodePtr newOutput(
new ShVariableNode(SH_OUTPUT, (*I)->size(),
00156 (*I)->specialType()));
00157
if ((*I)->has_name()) {
00158 newOutput->name((*I)->name());
00159 }
00160 std::replace(program->outputs.begin(), program->outputs.end(),
00161 (*I), newOutput);
00162 program->outputs.pop_back();
00163
00164 graphExit->block->addStatement(
ShStatement(
00165
ShVariable(newOutput),
SH_OP_ASN,
ShVariable(varMap[*I])));
00166 }
00167
00168 ShContext::current()->exit();
00169
00170 ShVariableReplacer replacer(varMap);
00171 program->ctrlGraph->dfs(replacer);
00172
00173
optimize(program);
00174
00175 program->collectVariables();
00176
return program;
00177 }
00178
00179 ShProgram combine(
ShProgram pa,
ShProgram pb)
00180 {
00181
ShProgramNodePtr a = pa.
node();
00182
ShProgramNodePtr b = pb.
node();
00183
00184 std::string rtarget;
00185
if( !a || !b ) SH_DEBUG_WARN(
"Connecting with a null ShProgram" );
00186
if (!a)
return b;
00187
if (!b)
return a;
00188
00189
if (a->target().empty()) {
00190 rtarget = b->target();
00191 }
else {
00192
if (b->target().empty() || a->target() == b->target()) {
00193 rtarget = a->target();
00194 }
else {
00195 rtarget =
"";
00196 }
00197 }
00198
00199
ShProgramNodePtr program =
new ShProgramNode(rtarget);
00200
00201
ShCtrlGraphNodePtr heada, taila, headb, tailb;
00202
00203 a->ctrlGraph->copy(heada, taila);
00204 b->ctrlGraph->copy(headb, tailb);
00205
00206 taila->append(headb);
00207
00208
ShCtrlGraphPtr new_graph =
new ShCtrlGraph(heada, tailb);
00209 program->ctrlGraph = new_graph;
00210
00211 program->inputs = a->inputs;
00212 program->inputs.insert(program->inputs.end(), b->inputs.begin(), b->inputs.end());
00213 program->outputs = a->outputs;
00214 program->outputs.insert(program->outputs.end(), b->outputs.begin(), b->outputs.end());
00215
00216
optimize(program);
00217
00218 program->collectVariables();
00219
00220
return program;
00221 }
00222
00223
00224 ShProgram mergeNames(ShProgram p)
00225 {
00226
typedef std::pair<std::string, int> InputType;
00227
typedef std::map< InputType, int > FirstOccurenceMap;
00228
typedef std::vector< std::vector<int> > Duplicates;
00229 FirstOccurenceMap firsts;
00230
00231
00232 Duplicates dups( p.node()->inputs.size(), std::vector<int>());
00233
00234 std::size_t i = 0;
00235
for(ShProgramNode::VarList::const_iterator I = p.node()->inputs.begin();
00236 I != p.node()->inputs.end(); ++I, ++i) {
00237 InputType it( (*I)->name(), (*I)->size() );
00238
if( firsts.find( it ) != firsts.end() ) {
00239 dups[ firsts[it] ].push_back(i);
00240 }
else {
00241 firsts[it] = i;
00242 dups[i].push_back(i);
00243 }
00244 }
00245 std::vector<int> swizzle;
00246 ShFixedManipulator duplicator;
00247
for(i = 0; i < dups.size(); ++i) {
00248
if( dups[i].empty() )
continue;
00249
for(std::size_t j = 0; j < dups[i].size(); ++j) swizzle.push_back(dups[i][j]);
00250
if( duplicator ) duplicator = duplicator & shDup(dups[i].size());
00251
else duplicator = shDup(dups[i].size());
00252 }
00253 ShProgram result = p <<
shSwizzle(swizzle);
00254
if( duplicator ) result = result << duplicator;
00255
return result.node();
00256 }
00257
00258 ShProgram namedCombine(
ShProgram a,
ShProgram b) {
00259
return mergeNames(
combine(a, b));
00260 }
00261
00262 ShProgram namedConnect(
ShProgram pa,
ShProgram pb,
bool keepExtra)
00263 {
00264
ShProgramNodeCPtr a = pa.
node();
00265
ShProgramNodeCPtr b = pb.
node();
00266
00267
typedef std::map<int, int> MatchedChannelMap;
00268
00269 std::vector<bool> aMatch(a->outputs.size(),
false);
00270 std::vector<bool> bMatch(b->inputs.size(),
false);
00271 MatchedChannelMap mcm;
00272 std::size_t i, j;
00273 ShProgramNode::VarList::const_iterator I, J;
00274
00275 i = 0;
00276
for(I = a->outputs.begin(); I != a->outputs.end(); ++I, ++i) {
00277 j = 0;
00278
for(J = b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00279
if(bMatch[j])
continue;
00280
if((*I)->name() != (*J)->name())
continue;
00281
if((*I)->size() != (*J)->size()) {
00282 SH_DEBUG_WARN(
"Named connect matched channel name " << (*I)->name()
00283 <<
" but output size " << (*I)->size() <<
" != " <<
" input size " << (*J)->size() );
00284
continue;
00285 }
00286 mcm[i] = j;
00287 aMatch[i] =
true;
00288 bMatch[j] =
true;
00289 }
00290 }
00291
00292 std::vector<int> swiz(b->inputs.size(), 0);
00293
for(MatchedChannelMap::iterator mcmit = mcm.begin(); mcmit != mcm.end(); ++mcmit) {
00294 swiz[mcmit->second] = mcmit->first;
00295 }
00296
00297
00298
ShProgram passer =
SH_BEGIN_PROGRAM() {} SH_END;
00299
int newInputIdx = a->outputs.size();
00300
for(j = 0, J= b->inputs.begin(); J != b->inputs.end(); ++J, ++j) {
00301
if( !bMatch[j] ) {
00302
ShProgram passOne =
SH_BEGIN_PROGRAM() {
00303
ShVariable var(
new ShVariableNode(SH_INOUT, (*J)->size(),
00304 (*J)->specialType()));
00305 var.name((*J)->name());
00306 } SH_END;
00307 passer = passer & passOne;
00308 swiz[j] = newInputIdx++;
00309 }
00310 }
00311
ShProgram aPass =
combine(pa, passer);
00312
00313
if( keepExtra ) {
00314
for(i = 0; i < aMatch.size(); ++i) {
00315
if( !aMatch[i] ) swiz.push_back(i);
00316 }
00317 }
00318
00319
return mergeNames(pb << (
shSwizzle(swiz) << aPass ));
00320 }
00321
00322 ShProgram renameInput(
ShProgram a,
00323
const std::string& oldName,
const std::string& newName) {
00324
ShProgram renamer =
SH_BEGIN_PROGRAM() {
00325
for(ShProgramNode::VarList::const_iterator I = a.
node()->inputs.begin();
00326 I != a.
node()->inputs.end(); ++I) {
00327
ShVariable var(
new ShVariableNode(SH_INOUT, (*I)->size(),
00328 (*I)->specialType()));
00329
00330
if (!(*I)->has_name())
continue;
00331 std::string name = (*I)->name();
00332
if( name == oldName ) {
00333 var.name(newName);
00334 }
else {
00335 var.name(name);
00336 }
00337 }
00338 } SH_END;
00339
return connect(renamer, a);
00340 }
00341
00342
00343 ShProgram renameOutput(
ShProgram a,
00344
const std::string& oldName,
const std::string& newName) {
00345
ShProgram renamer =
SH_BEGIN_PROGRAM() {
00346
for(ShProgramNode::VarList::const_iterator I = a.
node()->outputs.begin();
00347 I != a.
node()->outputs.end(); ++I) {
00348
ShVariable var(
new ShVariableNode(SH_INOUT, (*I)->size(),
00349 (*I)->specialType()));
00350
00351
if (!(*I)->has_name())
continue;
00352 std::string name = (*I)->name();
00353
if( name == oldName ) {
00354 var.name(newName);
00355 }
else {
00356 var.name(name);
00357 }
00358 }
00359 } SH_END;
00360
return connect(a, renamer);
00361 }
00362
00363 ShProgram namedAlign(
ShProgram a,
ShProgram b) {
00364
ShManipulator<std::string> ordering;
00365
00366
for(ShProgramNode::VarList::const_iterator I = b.
node()->inputs.begin();
00367 I != b.
node()->inputs.end(); ++I) {
00368 ordering((*I)->name());
00369 }
00370
00371
return ordering << a;
00372 }
00373
00374 ShProgram operator<<(
ShProgram a,
ShProgram b)
00375 {
00376
return connect(b,a);
00377 }
00378
00379 ShProgram
operator>>(ShProgram a, ShProgram b)
00380 {
00381
return connect(a,b);
00382 }
00383
00384 ShProgram operator&(
ShProgram a,
ShProgram b)
00385 {
00386
return combine(a, b);
00387 }
00388
00389 ShProgram operator>>(
ShProgram p,
const ShVariable &var) {
00390
return replaceUniform(p, var);
00391 }
00392
00393 ShProgram replaceUniform(
ShProgram a,
const ShVariable& v)
00394 {
00395
if(!v.
uniform()) {
00396
shError(
ShAlgebraException(
"Cannot replace non-uniform variable"));
00397 }
00398
00399
ShProgram program(a.
node()->clone());
00400
00401 ShVariableReplacer::VarMap varMap;
00402
00403 ShContext::current()->enter(program.
node());
00404
00405
00406
ShVariableNodePtr newInput =
new ShVariableNode(SH_INPUT, v.
size(), v.
node()->specialType());
00407 varMap[v.
node()] = newInput;
00408
00409 ShContext::current()->exit();
00410
00411 ShVariableReplacer replacer(varMap);
00412 program.
node()->ctrlGraph->dfs(replacer);
00413
00414
optimize(program);
00415
00416 program.
node()->collectVariables();
00417
00418
return program;
00419 }
00420
00421 }