source: container/qg_transform.py @ 245

Last change on this file since 245 was 245, checked in by marchulsman, 11 years ago

Bug in qg_transform where combination

File size: 34.2 KB
Line 
1import operator
2import traceback
3
4from basic_pass import Pass
5from qgraph import NodeVisitorFactory,Node
6from multi_visitor import NF_ELSE,NF_ROBJ,F_CACHE
7import qgraph_py
8import itypes_py
9from itypes_py import SegmentAll, SegmentFields,SegmentSet,SegmentInFix,SegmentFix,SegmentOutFix, SegmentPassThrough
10import utility
11import opcon
12from collections import deque
13
14
15
16def getLocalId(fieldid,fieldid_source,node):
17    if(node.out_calc is None):
18        node_calcids = {}
19    else:
20        node_calcids = node.out_calc
21    node_fieldids = node.fields
22
23    if(operator.isSequenceType(fieldid)):
24        nfieldid = []
25        for fid in fieldid:
26            while(not fid in node_fieldids and not fid in node_calcids):
27                try:
28                    fid = fieldid_source[fid]
29                    assert isinstance(fid,int),"Unexpected source id type (not an int)!"
30                except KeyError:
31                    raise RuntimeError,"Expected fieldid " + str(fid) + " could not be found in " + str(node.obj)
32            nfieldid.append(fid)
33        fieldid = nfieldid
34    else:
35        while(not fieldid in node_fieldids and not fieldid in node_calcids):
36            try:
37                fieldid = fieldid_source[fieldid]
38                assert isinstance(fieldid,int),"Unexpected source id!"
39            except KeyError:
40                raise RuntimeError,"Expected fieldid " + str(fieldid) + " could not be found in " + str(node.obj)
41       
42    return fieldid
43
44#ConsolidationPass: removes unused nodes{{{
45class ConsolidationPass(NodeVisitorFactory(flags=F_CACHE | NF_ROBJ),Pass):
46    after=(qgraph_py.PostOrderWalkPass,)
47
48    @classmethod
49    def run(cls,query):
50        """walk nodes in parent-first order, remove nodes from
51        depwalk that are removed from query graph by filtering
52        on output of visit method"""
53
54        self = cls(query.params) 
55        depwalk = query.pass_results[qgraph_py.PostOrderWalkPass]
56        depwalk = filter(self.visit,depwalk[::-1])[::-1]
57       
58        #update depwalk
59        query.pass_results[qgraph_py.PostOrderWalkPass] = depwalk
60        return (None,None,None)
61   
62    #most nodes require no action
63    def visitnode(self,node):
64        """Visits all nodes not to be removed. Returns True"""
65        return True
66   
67    #removes an unary node from the query graph
68    def removeUnaryNode(self,node):
69        """Remove nodes with only one source. Returns False."""
70        source_node = node.source[0]
71       
72        #targets of this node should point to its source
73        for target_node in node.target:
74            while(node in target_node.source):
75                node_idx = target_node.source.index(node)
76                target_node.source[node_idx] = source_node
77                source_node.target.append(target_node)
78       
79        #source node should not point to this node anymore
80        del source_node.target[source_node.target.index(node)]
81       
82        #cleanup (lets help the garbage collector)
83        node.source = []
84        node.target = []
85       
86        #returns False on success (used by filter as indication which nodes are removed)
87        return False
88
89    #node types to be removed
90    visitSelectCon = removeUnaryNode
91    visitUpCon = removeUnaryNode
92    visitNewCapCon = removeUnaryNode
93    visitChangeFieldNameCon = removeUnaryNode
94
95#}}}
96
97class PeepHoleOptimizerPass(NodeVisitorFactory(flags=NF_ROBJ),Pass):#{{{
98    after=(qgraph_py.FieldIdMapPass,qgraph_py.PostOrderWalkPass,ConsolidationPass)
99
100    NO_CHANGE = 0
101    CHANGE = 1
102    NEXT_ROUND = 2
103
104    @classmethod
105    def run(cls,query):
106        """Walks through nodes, attempting local optimization.
107        Each node is visited once in the first round, and can
108        decide to be part of the next round (indicated by return value).
109        The next round will only start if there is change in
110        the first round. (also indicated by return value)
111        Nodes can also be queued for visit by adding them manually
112        to self.cur_round or self.next_round. Nodes that get deleted
113        should be added to self.deleted_nodes to prevent their revisiting.
114        """
115        self = cls(query.params)
116        #initalization: get results from previous passes
117        depwalk = query.pass_results[qgraph_py.PostOrderWalkPass]
118        (self.fieldid_map,self.fieldid_source)= query.pass_results[qgraph_py.FieldIdMapPass]
119
120        self.cur_round = deque(depwalk) #children first
121        self.deleted_nodes = set()      #nodes deleted during optimization, not to be visited again
122        while(self.cur_round):
123            self.next_round = deque()   #nodes to visit in next round
124            change = False              #change in this round?
125            while(self.cur_round):
126                node = self.cur_round.popleft()
127                if(node in self.deleted_nodes):
128                    continue
129                res = self.visit(node)
130                if(res & self.CHANGE):
131                    change = True
132                    #debug: show tree after each change
133                    #qgraph_py.DrawNodeTreeVisitor.run(query)
134                if(res & self.NEXT_ROUND):
135                    self.next_round.append(node)
136
137            if(change): #if change, perform a new round
138                self.cur_round = self.next_round
139
140        #invalidate Postorderwalkpass
141        return (None,(qgraph_py.PostOrderWalkPass,),None)
142
143
144    ############# UTILITY FUNCTIONS ###################
145    def removeSource(self,node,pos):
146       """Removes source from node. Also removes accompanying segment,
147       and updates source id of other segments. Returns removed segment"""
148
149       nsegments = []
150       rsegment = None
151       for segment in node.segments:
152           assert not operator.isSequenceType(segment.source), "removeSource cannot handle segments with more than one source"
153           if(segment.source == pos):
154               rsegment = segment
155               continue
156           elif(segment.source > pos):
157               segment = segment.lazy_copy()
158               segment.source -= 1
159           nsegments.append(segment)
160       node.segments = nsegments
161       del node.source[pos]
162       return rsegment
163   
164   
165    def addCalcFields(self,node,calcfields):
166        """Add calcfields dictionary to node. If node already has a calcfields
167        dictionary, update it with the extra info."""
168   
169        if(node.out_calc is None):
170            node.out_calc = calcfields
171        else:
172            node.out_calc.update(calcfields)
173        self.next_round.append(node)
174
175    def moveUpCalcFields(self,node,source):
176       """Removes calcfields from node and adds them to source"""
177       if(not node.out_calc is None):
178           self.addCalcFields(source,node.out_calc)
179           node.out_calc = None
180   
181    def moveUpSetNodeCalcFields(self,node):#{{{
182        """Used for nodes which have only set segments  (join, merge,etc)
183        to move the calcfields to their source nodes if possible."""
184       
185        #if no calcfields present, return
186        if(node.out_calc is None):
187            return 0
188       
189        assert (len(set([segment.__class__ for segment in node.segments])) == 1),"Only SegmentSet expected in moveUpSetNodeCalcFields"
190
191        #obtain list of calcfields in this node
192        ch_list = node.out_calc.values()
193       
194        #sort calcfields on dependency by looking at field id
195        ch_list.sort(key=operator.itemgetter(0))
196       
197        change_status = False
198
199        #walk through calcfields that will be calculated here,
200        #parents firsT
201        for cid,param1,param2,op,type in ch_list:
202            for pos,segment in enumerate(node.segments):
203                field_set = segment.attr
204                if(param1 >= 0):
205                    param1 = getLocalId(param1,self.fieldid_source,node)
206                if(not param2 is None and param2 >= 0):
207                    param2 = getLocalId(param2,self.fieldid_source,node)
208
209                #check if calcfield dependencies fall completely within this segment
210                if((param1 < 0 or param1 in field_set) and (param2 is None or param2 < 0  or param2 in field_set)):
211                    #obtain source
212                    source = node.source[segment.source]
213
214                    #make copy of segment and add calculated field
215                    segment = segment.lazy_copy()
216                    node.segments[pos] = segment
217                   
218                    #add calcfield as field to this node
219                    node.fields.add(cid)
220                    segment.addField(cid)
221                   
222                    #source should be looked at again
223                    if(source not in self.next_round):
224                        self.next_round.append(source)
225                   
226                    #remove calc tuple from current node
227                    del node.out_calc[cid]
228
229                    #add new calc tuple to source node
230                    ncalc = (cid,param1,param2,op,type)
231                    self.addCalcFields(source,{ncalc[0]:ncalc})
232                   
233                    change_status = True
234
235                    break 
236        if(change_status):
237            #if calc environment empty, delete it
238            if(len(node.out_calc) == 0):
239                node.out_calc = None
240            return self.CHANGE
241        else:
242            return 0
243        #}}}
244
245    def visitWhereCon(self,node):#{{{
246        """Optimization of where nodes.
247           1) If parent is also a where node, combines this where node
248              with parent and removes this where node
249           2) If parent is join node, attempts to move constraints before
250              join
251        """
252
253        #has this node a single source that only references this node?
254        if(not (len(set(node.source)) == 1 \
255                and len(set(node.source[0].target))==1)):
256            #if not, we cannot do anything now, lets wait until next round
257            return self.NEXT_ROUND
258
259        #is this a where which uses a fixed constraint (i.e. a slice or index?)
260        if(len(node.segments) == 1):
261            #cannot be reordered, so have to keep it as is
262            return 0 
263
264        parent = node.source[0]  #single parent for this where node
265        not_remove_idx = [0]     #source indexes not to remove from this node
266       
267        if(parent.obj is opcon.WhereCon):
268            #add this where node to parent node
269
270            parent_parent = parent.source[0]
271
272            old_sourcenr = len(parent.source)
273            #add the necessary sources,targets and segments
274            parent.source = parent.source + [parent_parent] * len(node.source[1:])
275            parent_parent.target = parent_parent.target + [parent] * len(node.source[1:])
276            nsegments = [segment.lazy_copy() for segment in node.segments[1:]]
277           
278            #copied segments need new source nr and
279            for pos,segment in enumerate(nsegments):
280                segment.source = pos + old_sourcenr
281
282            parent.segments = parent.segments + nsegments
283           
284            #copy field calculation directives upward (need them a node earlier)
285            self.moveUpCalcFields(node,parent)
286            self.moveUpCalcFields(parent,parent_parent)
287
288        elif(issubclass(parent.obj,opcon.JoinOpCon)):
289            for pos,segment in enumerate(node.segments[1:]): 
290                assert isinstance(segment,SegmentInFix),"Expected only SegmentInfix-es"
291               
292                constraint_fieldid = getLocalId(segment.attr[0],self.fieldid_source,parent)
293                if(parent.out_calc is None or not constraint_fieldid in parent.out_calc):
294                    for segment in parent.segments:
295                        assert isinstance(segment,SegmentSet),"Expected only SegmentSets"
296                        if(constraint_fieldid in segment.attr):
297                            sourceidx = segment.source
298                            source = parent.source[sourceidx]
299
300                            newnode = Node(opcon.WhereCon,[source,source])
301                            newnode.target = [parent]
302                            newnode.fields = source.fields
303                            nid = [getLocalId(constraint_fieldid,self.fieldid_source,source)]
304                            newnode.segments = [itypes_py.SegmentAll(0),itypes_py.SegmentInFix(1,nid)]
305                           
306                            parent.source[sourceidx] = newnode
307                            source.target[source.target.index(parent)] = newnode
308                            source.target.append(newnode)
309                            self.cur_round.appendleft(newnode)
310                            break
311                    else:
312                        raise RuntimeError, "Calc field not found in source of join"
313                else:
314                    not_remove_idx.append(pos+1)
315        else:
316            return self.NEXT_ROUND
317
318        if(len(not_remove_idx) == len(node.source)):
319            return self.NEXT_ROUND
320        #remove conditions from where node
321        elif(len(not_remove_idx) > 1):
322            removenr = len(node.source) - len(not_remove_idx)
323            node.source = [source for pos,source in enumerate(node.source) \
324                                    if pos in not_remove_idx]
325
326            nsegments = [segment.lazy_copy() for segment in node.segments[1:] \
327                                    if segment.source in not_remove_idx]
328
329            for pos,segment in enumerate(nsegments):
330                segment.source = pos + 1
331            node.segments = [node.segments[0]] + nsegments
332
333            while(removenr > 0):
334                del parent.target[parent.target.index(node)]
335                removenr -= 1
336            return self.NEXT_ROUND | self.CHANGE
337        #or remove whole where node
338        else:
339            self.moveUpCalcFields(node,parent)
340            parent.target = node.target
341            for t in node.target:
342                for pos,source in enumerate(t.source):
343                    if(source is node):
344                        t.source[pos] = parent
345            node.target = []
346            node.source = [parent]
347            self.deleted_nodes.add(node)
348            return self.CHANGE
349        #}}}
350
351    def visitJoinWhereCon(self,node):
352        return 0
353
354    def visitnode(self,node):
355        return 0
356   
357
358
359    def visitScalarSrcCon(self,node):#{{{
360       """Attempts to remove scalar, by converting it to calculated field"""
361
362       assert (len(node.fields) == 1),"More than one field in scalar"
363     
364       #fixme: geen calc support
365       if(not node.out_calc is None):
366           return 0
367       
368       #get scalar field  (getting just the first element of set without
369       #modifying is rather hard...)
370       field = [self.fieldid_map[fieldid] for fieldid in node.fields][0]
371       
372       all_removed = True
373       any_removed = False
374       for target in node.target:
375          if(issubclass(target.obj,opcon.MergeCon)):
376            #add as calcfield to merge operation,remove scalar source
377           
378            any_removed = True
379            #removing source
380            pos = target.source.index(node)
381            rsegment = self.removeSource(target,pos)
382            nfieldid = list(rsegment.attr)[0]
383           
384            calcfield = (nfieldid,- (node.source[0] + 1),None,"SET",field.type)
385            self.addCalcFields(target,{nfieldid:calcfield})
386            target.fields.discard(nfieldid)
387           
388            #lets take a look at target again in the next round
389            self.next_round.append(target)
390
391          elif(issubclass(target.obj,opcon.BroadcastOpCon)):
392            #add as param source to broadcast opcon
393            any_removed = True
394            pos = target.source.index(node)
395
396            #modify source of broadcast to source of scalar
397            target.source[pos] = node.source[0]
398
399            #modify broadcastopcon segment to set scalar source to None
400            target.segments = [target.segments[0].lazy_copy()]
401            csource = list(target.segments[0].source)
402            csource[pos] = None
403            target.segments[0].source = tuple(csource)
404
405            self.next_round.append(target)
406          else:
407            all_removed = False
408       
409       if(all_removed):
410           node.source = []
411           node.target = []
412           self.deleted_nodes.add(node)
413           return self.CHANGE
414           
415       if(any_removed):
416           return self.CHANGE | self.NEXT_ROUND
417       else:
418           return self.NEXT_ROUND#}}}
419
420    def visitBroadcastOpCon(self,node):
421        (left_source,right_source) = node.source[:2]
422       
423        op = self.params[node.source[2]] 
424        if(left_source is right_source):
425            calcfields = []
426            for fieldid in node.fields:
427                field = self.fieldid_map[fieldid]
428                calcfields.append((fieldid,field.sourceids[0],field.sourceids[1],op,field.type))
429                del self.fieldid_source[fieldid]
430            left_source.target.remove(node)
431            right_source.target.remove(node)
432        elif(isinstance(right_source,int)):
433            if(isinstance(left_source,int)):
434                return 0
435            else:
436                calcfields = []
437            for fieldid in node.fields:
438                field = self.fieldid_map[fieldid]
439                calcfields.append((fieldid,field.sourceids[0],-(right_source + 1),op,field.type))
440                del self.fieldid_source[fieldid]
441            left_source.target.remove(node)
442        else:
443            return self.NEXT_ROUND
444
445
446        self.moveUpCalcFields(node,left_source)
447        self.addCalcFields(left_source,dict([(calcinfo[0],calcinfo) for calcinfo in calcfields]))
448           
449        for target in node.target:
450            for tsource_idx,tsource in enumerate(target.source):
451                if(tsource is node):
452                    target.source[tsource_idx] = left_source
453                    left_source.target.append(target)
454
455        self.deleted_nodes.add(node)
456        return self.CHANGE
457
458    def visitMergeCon(self,node):
459        res = self.moveUpSetNodeCalcFields(node)
460
461        if(len(set(node.source)) == 1):
462           #remove merge node
463           assert (node.out_calc is None),"Calc records in removable merge node found"
464           parent = node.source[0]
465
466           for target in node.target:
467                poss = [pos for pos,source in enumerate(target.source) if source is node]
468                for pos in poss:
469                    target.source[pos] = parent
470                    parent.target.append(target)
471               
472                self.next_round.append(target)
473           
474           for i in range(len(node.source)):
475               del parent.target[parent.target.index(node)]
476           self.next_round.append(parent)
477           node.source = []
478           node.target = []
479           self.deleted_nodes.add(node)
480           return self.CHANGE
481        else:
482           return res
483   
484    def visitEquiJoinOpCon(self,node):
485        res = self.moveUpSetNodeCalcFields(node)
486        for target in node.target:
487            self.next_round.append(target)
488        return res
489
490    def visitJoinOpCon(self,node):#{{{
491       
492        self.moveUpSetNodeCalcFields(node)
493        op = self.params[node.source[2]]
494        target_set = set(node.target)
495
496        if(op == "CROSS JOIN"):
497            node.obj = opcon.CrossJoinOpCon
498            return self.NEXT_ROUND
499
500        #if there are multiple 'users' of this join we cannot easily modify it.
501        #here we check if we should split it
502        if(len(target_set) > 1):
503            return self.NEXT_ROUND
504
505        target = target_set.pop()
506       
507        #the target should be a where op, and not be
508        #a row-based where operation (no fields in constraint segment)
509        #which cannot be used as join-condition
510        if(not issubclass(target.obj,opcon.WhereCon) or len(target.segments) == 1):
511            raise RuntimeError, "Could not find join-condition for " + op
512         
513        #join conditions for outer joins should be specified as being the join condition
514        if(op == "LEFT JOIN" or op == "RIGHT JOIN" or op == "FULL JOIN"):
515            if(not issubclass(target.obj,opcon.JoinWhereCon)):
516                raise RuntimeError, "Could not find join-condition for " + op
517       
518        #where conditions calculations should be done in the join node
519        #if there are none we are sure that there is no join condition
520        if(node.out_calc is None):
521             raise RuntimeError, "Could not find join-condition for " + op
522       
523        calc_here = node.out_calc
524        conditions = [getLocalId(segment.attr[0],self.fieldid_source,node) for segment in target.segments[1:]]
525
526        req_fields = []
527        if(len(conditions) == 1):
528            condition = self.build_condition(node,calc_here,conditions[0],req_fields)
529            (p1,opfield,p2,type) = condition
530            if(op == "INNER JOIN" \
531                and opfield == "__eq__" and \
532                not isinstance(p1,tuple) and not isinstance(p2,tuple)):
533                node.obj = opcon.EquiJoinOpCon
534        else:
535            left = self.build_condition(node,calc_here,conditions[0],req_fields)
536            right= self.build_condition(node,calc_here,conditions[1],req_fields)
537            condition = (left,"__and__",right,itypes_py.createType("bool"))
538            for cond in conditions[2:]:
539                left = cls.build_condition(calc_here,cond,req_fields)
540                condition = (left,"__and__",condition,itypes_py.createType("bool"))
541
542        node.exec_params = condition
543        node.req_fields = req_fields
544
545        #every condition of the where node has to be removed, so we
546        #will remove it completely
547        self.moveUpCalcFields(target,node)
548        node.target = target.target
549        #update targets to point to join instead
550        for tt in target.target:
551            for pos,source in enumerate(tt.source):
552                if(source is target):
553                    tt.source[pos] = node
554        self.deleted_nodes.add(target)
555        return self.CHANGE
556
557    def build_condition(self,node,calc_here,cond,req_fields):
558       (id,param1,param2,op,type) = calc_here[cond]
559       if(param1 >= 0):
560            param1 = getLocalId(param1,self.fieldid_source,node)
561            if(param1 in calc_here):
562                 param1 = self.build_condition(node,calc_here,param1,req_fields)
563            else:
564                 req_fields.append(param1)
565                 param1 = len(req_fields) - 1
566
567       if(param2 >= 0):
568            param2 = getLocalId(param2,self.fieldid_source,node)
569            if(param2 in calc_here):
570                param2 = self.build_condition(node,calc_here,param2,req_fields)
571            else:
572                req_fields.append(param2)
573                param2 = len(req_fields) - 1
574
575       return (param1,op,param2,type)#}}}
576
577#}}}
578
579#UsedFieldInferencePass: Determines for each node the fields that #{{{
580#have to be obtained from the source
581class UsedFieldInferencePass(Pass): 
582   
583    after=(qgraph_py.PostOrderWalkPass,qgraph_py.FieldIdMapPass,\
584            ConsolidationPass)   
585
586    @classmethod
587    def run(cls,query):
588        """Determine for each the fields that have to be obtained from the source"""
589
590        depwalk = query.pass_results[qgraph_py.PostOrderWalkPass]
591        (fieldid_map,fieldid_source) = query.pass_results[qgraph_py.FieldIdMapPass]
592
593        root = depwalk[0]
594        res_in = [root.segments[0].attr]
595        root.out_fields,root.in_fields = cls.inferOutInFields(root,res_in,None,fieldid_map,fieldid_source)
596       
597        for node in depwalk[1:]: #children first walk (exclude root node)
598            #STEP 1: infer required fields in this node
599
600            #obtain collection of all in_fields of target nodes
601            in_fields = [[target.in_fields[tsource_idx]\
602                                for tsource_idx,tsource in enumerate(target.source)\
603                                if tsource is node]\
604                            for target in set(node.target)]
605           
606            #combine in_field arrays of multiple targets
607            in_fields = reduce(operator.add,in_fields)
608           
609            #longest first (that way we do not have to reorder its input)
610            in_fields = sorted(in_fields,key=len,reverse=True)
611           
612            #flatten the in_field arrays
613            in_fields = reduce(operator.add,in_fields)
614            if(not node.req_fields is None):
615                in_fields = in_fields + node.req_fields
616       
617            #uniqify the list while keeping it ordered
618            req_field_ids_set = set()
619            req_field_ids = [field_id for field_id in in_fields\
620                            if field_id not in req_field_ids_set and \
621                            not req_field_ids_set.add(field_id)]
622
623            if(not node.out_calc is None):
624                calc_info = node.out_calc
625                #determine which of the calculated fields are in the output
626                req_calc_output_set = set(calc_info.keys()) & req_field_ids_set
627
628           
629                if(not req_calc_output_set):
630                    #if no calculated fields are required, set to None
631                    req_calc_output_set = None #calc fields in output
632                    node.out_calc = None   #fields to calculate here (including dependencies output)
633                else:
634                    #determine calculated output field dependencies
635                    req_calc_here_set = req_calc_output_set.copy() #make copy for queue uniqueness
636
637                    #walk through dependencies and add
638                    #to req_field_ids if not calculated here
639                    #otherwise, add to req_calc_here
640                    calc_queue = list(req_calc_output_set)
641                   
642                    nout_calc = {}
643                    while(calc_queue):
644                        c_id = calc_queue.pop()
645                        (id,param1,param2,op,type) = calc_info[c_id]
646                        if(param1 >= 0):
647                            param1 = getLocalId(param1,fieldid_source,node)
648                            if(param1 in calc_info and not param1 in req_calc_here_set):
649                                calc_queue.append(param1)
650                                req_calc_here_set.add(param1)
651                            elif(param1 not in req_field_ids_set):
652                                req_field_ids.append(param1)
653                                req_field_ids_set.add(param1)
654                        if(param2 >= 0):
655                            param2 = getLocalId(param2,fieldid_source,node)
656                            if(param2 in calc_info and not param2 in req_calc_here_set):
657                                calc_queue.append(param2)
658                                req_calc_here_set.add(param2)
659                            elif(param2 not in req_field_ids_set):
660                                req_field_ids.append(param2)
661                                req_field_ids_set.add(param2)
662                       
663                        nout_calc[id] = (id,param1,param2,op,type)
664                    node.out_calc = nout_calc 
665            else:
666                req_calc_output_set = None
667                node.out_calc = None
668           
669           
670            node.out_fields,node.in_fields = cls.inferOutInFields(node,req_field_ids,req_calc_output_set,fieldid_map,fieldid_source)
671        return (None,None,None)
672
673    @classmethod
674    def inferOutInFields(cls,node,req_field_ids,req_calc_out_ids,fieldid_map,fieldid_source):
675        in_field_ids_col = [None] * len(node.source)
676        out_field_ids = []
677
678        #CALC SEGMENT (all nodes have one by default)
679        if(not req_calc_out_ids is None):
680            out_field_ids.extend(req_calc_out_ids)
681            #remove field that are calced here from request list
682            req_field_ids = [rfi for rfi in req_field_ids if not rfi in req_calc_out_ids]
683        for segment in node.segments:
684           
685            if(isinstance(segment,SegmentAll)):
686                res_out = res_in = req_field_ids
687                out_field_ids.extend(res_out)
688                in_field_ids_col[segment.source] = getLocalId(res_in,fieldid_source,node.source[segment.source])
689            elif(isinstance(segment,SegmentSet)):
690                cur_field_set = segment.attr
691                res_out = res_in =  [req_id for req_id in req_field_ids\
692                                        if req_id in cur_field_set]
693                out_field_ids.extend(res_out)
694                in_field_ids_col[segment.source] = getLocalId(res_in,fieldid_source,node.source[segment.source])
695            elif(isinstance(segment,SegmentInFix)):
696                res_in = segment.attr
697                in_field_ids_col[segment.source] = getLocalId(res_in,fieldid_source,node.source[segment.source])
698            elif(isinstance(segment,SegmentOutFix)):
699                out_field_ids.extend(segment.attr)
700            elif(isinstance(segment,SegmentFix)):
701                res_in = segment.attr
702                in_field_ids_col[segment.source] = getLocalId(res_in,fieldid_source,node.source[segment.source])
703                out_field_ids.extend(segment.attr)
704            elif(isinstance(segment,SegmentFields)):
705                #if no segment_attr consider all fields with sourceids to be part of
706                #this reduce segment
707                if(segment.attr is None):
708                    res = [rid for rid in req_field_ids\
709                                if rid in fieldid_source]
710                #else use all fields that have their id in the segment_attr
711                else: 
712                    reduce_field_ids = set(segment.attr)
713                    res = [rid for rid in req_field_ids \
714                                    if rid in reduce_field_ids]
715
716                #results with None field id's are not used
717                for operand_idx,source_idx in enumerate(segment.source):
718                    if(source_idx is None):
719                        continue
720                    res_in = [fieldid_source[rid][operand_idx] for rid in res]
721
722                    if(not None in res_in):
723                        in_field_ids_col[source_idx] = getLocalId(res_in,fieldid_source,node.source[source_idx])
724               
725                out_field_ids.extend(res)
726
727        return (out_field_ids,in_field_ids_col)
728   
729       
730
731
732#CalcFieldTransformPass: Calculates for each node the relative indexes#{{{
733#of the needed fields
734class CalcFieldTransformPass(Pass): 
735    after=(UsedFieldInferencePass,qgraph_py.PostOrderWalkPass,qgraph_py.FieldIdMapPass,\
736        PeepHoleOptimizerPass)
737
738    @classmethod
739    def run(cls,query):
740        """Calculate for each node the relative indexes of the required fields"""
741
742        depwalk = query.pass_results[qgraph_py.PostOrderWalkPass]
743        (fieldid_map,fieldid_source) = query.pass_results[qgraph_py.FieldIdMapPass]
744       
745        #first pass: determine out field positions, node.fields reset to real field objects
746        for node in depwalk:
747            outcalc_count = 0
748            if(node.out_calc is None):
749                node.fields = [fieldid_map[oid] for oid in node.out_fields]
750                node.out_fields = dict([(oid,index) for index,oid in enumerate(node.out_fields)])
751            else:
752                node.fields = [fieldid_map[oid] for oid in node.out_fields if oid not in node.out_calc]
753                node.out_fields = dict([(oid,index) for index,oid in enumerate(node.out_fields)])
754                req_calc_here = node.out_calc.values()
755                req_calc_here.sort(key=operator.itemgetter(0))
756                req_calc_here_pos = dict([(calcinfo[0],pos) for pos,calcinfo in enumerate(req_calc_here)])
757
758                req_calc = [0]
759                for pos,(id,param1,param2,op,type) in enumerate(req_calc_here):
760                    if(param1 >= 0):
761                        #param1 = getLocalId(param1,fieldid_source,node)
762                        if(param1 in node.out_fields):
763                            p1 = node.out_fields[param1]
764                        else:
765                            p1 = len(node.out_fields) + req_calc_here_pos[param1]
766                    else:
767                        p1 = param1
768                       
769                    if(not param2 is None and param2 >= 0):
770                        #param2 = getLocalId(param2,fieldid_source,node)
771                        if(param2 in node.out_fields):
772                            p2 = node.out_fields[param2]
773                        else:
774                            p2 = len(node.out_fields) + req_calc_here_pos[param2]
775                    else:
776                        p2 = param2
777
778                    if(id in node.out_fields):
779                        outpos = node.out_fields[id]
780                        outcalc_count += 1
781                    else:
782                        outpos = -1
783                    req_calc.append((p1,p2,op,type,outpos))
784                req_calc[0] = outcalc_count
785                if(outcalc_count != 0):
786                    node.out_calc = tuple(req_calc)
787                else:
788                    node.out_calc = None
789            if(node.req_fields):
790                #subtract nr. of calcfields as req_fields is used internally in the node
791                #before the calcfields are added
792                node.req_fields = [node.out_fields[id] - outcalc_count for id in node.req_fields]
793
794        for node in depwalk:
795            for source_idx,source in enumerate(node.source):
796                if(not isinstance(source,Node)):
797                    continue
798                in_field_ids = node.in_fields[source_idx]
799                out_field_idx = source.out_fields
800                res = [out_field_idx[in_id] for in_id in in_field_ids]
801
802                #Now simplify transformation if possible
803                if(len(res) == 1): #single fields dont need list
804                    #if source is also single field, no trans needed
805                    if(len(out_field_idx) == 1): 
806                        res = None
807                    else:
808                        res = res[0]
809                       
810                #check if strictly ordered (i.e. 1,2,3,4 is strictly orderded,
811                #, 1,3,4 is not)
812                elif(len(res) > 0 and all([(r - l) == 1 for (l,r) in zip(res[:-1],res[1:])])):
813                    #if length equal to out fields, no trans is needed
814                    if(len(res) == len(out_field_idx)):
815                         res = None
816                    #otherwise, we can use a slice index (faster)
817                    else:
818                         res = slice(res[0],res[-1] + 1)
819               
820                node.in_fields[source_idx] = res
821           
822            node.out_fields = None
823            if(all([in_f is None for in_f in node.in_fields])):
824                node.in_fields = None
825
826               
827
828        return (None,(UsedFieldInferencePass,),None)
829#}}}
Note: See TracBrowser for help on using the repository browser.