Changeset 57


Ignore:
Timestamp:
May 25, 2009, 1:38:28 PM (11 years ago)
Author:
marchulsman
Message:

DB query support

Location:
container
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • container/SQL_pass.py

    r56 r57  
    11import basic_pass
    22import qg_rewrite
    3 
    4 
     3from qgraph import NodeVisitorFactory,Node
     4from multi_visitor import NF_ELSE,NF_ROBJ,F_CACHE
     5from base_container import QueryResult
     6import qgraph_py
     7import engine
     8import utility
     9import xnumpy
     10import itypes_py
     11import numpy
     12
     13tableseqgen = utility.seqgen().next#{{{
    514class SQLQuery(object):
    6     __slots__ = ['FROM','WHERE','FIELDS']
    7 
    8 class SQL_pass(basic_pass.Pass):
     15    __slots__ = ['table','cond','fields','fieldtypes','conn','alias']
     16
     17    @classmethod
     18    def create(cls,tablename,fieldnames,fieldtypes,conn):
     19        self = cls()
     20        alias = self.new_alias(tablename)
     21
     22        self.alias = set([alias])
     23        self.table = tablename + " AS " + alias
     24        self.fields = xnumpy.dimarray([alias + "." + fn for fn in fieldnames],object,1,1)
     25        self.fieldtypes = fieldtypes
     26        self.conn = conn
     27        self.cond = []
     28        return self
     29
     30    @classmethod
     31    def merge(self,*sqlquerys):
     32        if(not sqlquerys):
     33            return False
     34
     35        conn = sqlquerys[0].conn
     36        table = sqlquerys[0].table
     37        fields = []
     38        fieldtypes = []
     39        cond = set()
     40        alias = set()
     41        for sq in sqlquerys:
     42            if(not (sq.conn is conn and sq.table == table)):
     43                return False
     44            fields.append(sq.fields)
     45            fieldtypes.append(sq.fieldtypes)
     46            for c in sq.cond:
     47                cond.add(c)
     48            alias += sq.alias
     49       
     50        n = SQLQuery()
     51        n.table = table
     52        n.conn = conn
     53        n.fields = numpy.hstack(fields)
     54        n.fieldtypes = numpy.hstack(fieldtypes)
     55        n.cond = list(cond)
     56        n.alias = alias
     57        return n
     58   
     59    @classmethod
     60    def createParam(self,idx):
     61        return '%(val' + str(idx) + ')s'
     62
     63    def setActiveFields(self,sel):
     64        self.fields = self.fields[sel]
     65        self.fieldtypes = self.fieldtypes[sel]
     66        return self
     67
     68    def copy(self):
     69        n = SQLQuery()
     70        n.table = self.table
     71        n.fields = self.fields
     72        n.fieldtypes = self.fieldtypes
     73        n.conn = self.conn
     74        n.cond = list(self.cond)
     75        n.alias = self.alias.copy()
     76        return n
     77
     78    def new_alias(self,name):
     79        if(" " in name or len(name) > 64):
     80            return "realias" + str(tableseqgen()) + "S"
     81        else:
     82            return name + str(tableseqgen()) + "S"
     83
     84    def realias(self,alias):
     85        nalias = new_alias(self,alias)
     86        for i,field in enumerate(self.fields):
     87            self.fields[i] = self.fields[i].replace(alias,nalias)
     88   
     89        for i,cond in enumerate(self.cond):
     90            self.cond[i] = self.cond[i].replace(alias,nalias)
     91
     92        self.tables = self.tables.replace(alias,nalias)
     93        self.alias.remove(alias)
     94        self.alias.add(nalias)
     95
     96    def __str__(self):
     97        sel = ", ".join(self.fields)
     98        sql = "SELECT " + sel + " FROM " + self.table
     99        if(self.cond):
     100            sql = sql + " WHERE " + " AND ".join(self.cond)
     101        return sql       
     102
     103    def executeQuery(self,cursor,params):
     104        sel = ", ".join(self.fields)
     105        pdict = dict([("val" + str(i),p)  for i,p in enumerate(params)])
     106        sql = "SELECT " + sel + " FROM " + self.table
     107        if(self.cond):
     108            sql = sql + " WHERE " + " AND ".join(self.cond)
     109        cursor.execute(sql,pdict)
     110        print cursor.query
     111#}}}
     112
     113
     114class SQL_pass(NodeVisitorFactory(flags=F_CACHE | NF_ROBJ),basic_pass.Pass):
    9115    @classmethod
    10116    def ready(cls,pass_results):
    11117        return qg_rewrite.FieldTransformPass in pass_results
     118
     119    @classmethod
     120    def invalidates(cls):
     121        return (qg_rewrite.DependencyWalkPass,)
     122
    12123   
    13124    @classmethod
    14     def run(self,query):
     125    def run(cls,query):
    15126        self = cls(query.params)
    16127        self.visit(query.qg)
    17128
    18129    def visitDBSrcCon(self,node):
     130        tablename = node.props[node.props.name == 'tablename'].value.item()
     131        conn = node.props[node.props.name == 'conn'].value.item()
     132        fieldnames = node.fields.name
     133        fieldtypes = node.fields.type
     134
     135        x = self.outtrans(node,SQLQuery.create(tablename,fieldnames,fieldtypes,conn))
     136        return x
     137 
     138
     139    binop_translate = {'__add__':'+','__mul__':'*','__and__':' AND ','__or__':' OR ','__eq__':'=','__le__':'<=','__ge__':'>=','__lt__':'<','__gt__':'>','__ne__':'!='}
     140    def visitArithmeticOpCon(self,node):
     141        (l,r,op) = self.visitsources(node)
     142        if(isinstance(l,SQLQuery)):
     143            if(isinstance(r,SQLQuery)):
     144                if(not (l.table == r.table and \
     145                        l.cond == r.cond and \
     146                        l.conn is r.conn)):
     147                     self.replaceSourceNode(node,0,l)
     148                     self.replaceSourceNode(node,1,r)
     149                     return None
     150            elif(r is None or isinstance(r,QueryResult)):
     151                self.replaceSourceNode(node,0,l)
     152                return None
     153            else:
     154                r = SQLQuery()
     155                r.fields = xnumpy.dimarray((SQLQuery.createParam(node.source[1]),),object,1,1)
     156        else:
     157            if(isinstance(r,SQLQuery)):
     158                self.replaceSourceNode(node,1,r)
     159            return None
     160       
     161        if(len(r.fields) ==1):
     162            nfields = xnumpy.dimarray([lf + " " + self.binop_translate[op] + " " + r.fields[0] for lf in l.fields],object,1,1)
     163        else:
     164            nfields = xnumpy.dimarray([lf + " " + self.binop_translate[op] + " " + rf for lf,rf in zip(l.fields,r.fields)],object,1,1)   
     165       
     166        l.fields = nfields
     167        return self.outtrans(node,l)
     168
     169    def visitWhereCon(self,node):
     170        (s,r) = self.visitsources(node)
     171        if(isinstance(s,SQLQuery)):
     172            if(isinstance(r,SQLQuery)):
     173                if(not (s.table == r.table and \
     174                        s.cond == r.cond and \
     175                        s.conn is r.conn)):
     176                     self.replaceSourceNode(node,0,s)
     177                     self.replaceSourceNode(node,1,r)
     178                     return None
     179                s.cond.append(r.fields[0])
     180                return self.outtrans(node,s)
     181            else:
     182                self.replaceSourceNode(node,0,s)
     183        else:
     184            if(isinstance(r,SQLQuery)):
     185                self.replaceSourceNode(node,1,r)
     186
     187
     188    def visitCrossJoinOpCon(self,node):
     189        (l,r,op) = self.visitsources(node)
     190        if(isinstance(l,SQLQuery)):
     191            if(isinstance(r,SQLQuery)):
     192                if(op == "CROSS JOIN"):
     193                    intersect_alias = l.alias & r.alias
     194                    for ia in intersect_alias:
     195                        r.realias(ia)
     196                    l.table = l.table + " " + op + " " + r.table
     197                    l.fields = numpy.hstack((l.fields,r.fields))
     198                    l.fieldtypes = numpy.hstack((l.fieldtypes,r.fieldtypes))
     199                    l.cond.extend(r.cond)
     200                    l.alias = l.alias | r.alias
     201
     202                    return self.outtrans(node,l)
     203                else:
     204                     self.replaceSourceNode(node,0,l)
     205                     self.replaceSourceNode(node,1,r)
     206            else:
     207               self.replaceSourceNode(node,0,l)
     208        else:
     209            if(isinstance(r,SQLQuery)):
     210                self.replaceSourceNode(node,1,r)
     211
     212
     213    def outtrans(self,node,result):
     214        if(not node.out_trans is None):
     215            nres = []
     216            for i,otrans in enumerate(node.out_trans[2:]):
     217                if(not otrans is None):
     218                    nres.append(res[i].copy().setActiveFields(otrans))
     219            if(node.out_trans[0] is None):
     220                nres.append(result)
     221            else:
     222                nres.append(result.setActiveFields(node.out_trans[0]))
     223            print nres
     224
     225            if(len(nres) == 1):
     226                result = nres[0]
     227            else:
     228                result = SQLObject.merge(nres)
     229                if(res is False): #merging failed
     230                    for idx,s in enumerate(node.source):
     231                        if(isinstance(s,Node)):
     232                            self.replaceSourceNode(node,idx,self.visit(node))
     233                    return None
     234        return result
     235
     236    def visit(self,node):
     237        result = super(SQL_pass,self).visit(node)
     238       
     239        if(not result is None and isinstance(node,Node) and len(node.target) > 1):
     240            result = result.copy()
     241        return result
     242
     243    def visitsources(self,node):
     244        res = super(SQL_pass,self).visitsources(node)
     245        if(not node.in_trans is None):
     246            inres = list(res)
     247            for i,a_idx in enumerate(node.in_trans):
     248                if(not a_idx is None and isinstance(inres[i],SQLQuery)):
     249                    inres[i].setActiveFields(a_idx)
     250            return tuple(inres)
     251        return res
     252
     253    def visitnode(self,node):
     254        for idx,s in enumerate(node.source):
     255            if(isinstance(s,Node)):
     256                r = self.visit(s)
     257                print s.obj,r
     258                if(not r is None):
     259                    self.replaceSourceNode(node,idx,r)
     260               
     261
     262    def replaceSourceNode(self,curnode,sidx,sqlobj):
     263        source = [sqlobj]
     264        n = Node(SQLToPyCon,source)
     265        s = curnode.source[sidx]
     266        del s.target[s.target.index(curnode)]
     267        curnode.source[sidx] = n
     268
     269
     270class SQLToPyCon(qgraph_py.ExtendCon):
     271    @classmethod
     272    def execute(self,visitor,node):
     273        (sqlobj,) = node.source
     274        ftypes = sqlobj.fieldtypes
     275        conn = sqlobj.conn
     276
     277        cursor = conn.getSingletonCursor()
     278        sqlobj.executeQuery(cursor,visitor.params)
     279        res = cursor.fetchall()
     280        ftypes = itypes_py.to_numpy(utility.ensure_seq(ftypes))
     281        if(len(res) > 0):
     282            res = itypes_py.transpose(res)
     283            data = xnumpy.dimarray(tuple((xnumpy.dimarray(col,ftype) for col,ftype in zip(res,ftypes))),object,1,1)
     284            nrow = len(res[0])
     285            ncol = len(res)
     286            return QueryResult(data,nrow,ncol)
     287        else:
     288            data = xnumpy.dimarray(tuple(numpy.empty((0,),dtype=ft)for ft in ftypes),object,1,1)
     289            return QueryResult(data,0,len(ftypes))
     290
     291engine.e.pre_pm.register(SQL_pass)
     292engine.e.pre_pm.register(qg_rewrite.DrawFieldTreeVisitor)
     293
     294
     295
  • container/engine.py

    r56 r57  
    3939
    4040
     41import SQL_pass
  • container/opcon.py

    r56 r57  
    3939            props.fields = map(numpy.abs,props.fields)
    4040            for i,f in enumerate(props.fields):
    41                 props.fields[i] = f * v
     41                if(len(f) > 0):
     42                    props.fields[i] = f * v
    4243            self._finishProps(props)
    4344   
  • container/postgres.py

    r56 r57  
    1010from base_container import QueryResult
    1111import basic_pass
    12 
    1312
    1413class LogCursor(psycopg2.extensions.cursor):
  • container/qg_execute.py

    r56 r57  
    3030                return QueryResult(data=data,nrow=len(data[0]),ncol=len(data))
    3131        return result
     32
    3233
    3334    def visitsources(self,node):
     
    7172        return QueryResult(data,lr.nrow,ncol)
    7273
    73     def visitSrcCon(self,node):
     74    def visitPySrcCon(self,node):
    7475        (res,) = self.visitsources(node)
    7576        return res
     
    7879        (res,) = self.visitsources(node)
    7980        return res
    80    
     81   
     82    def visitExtendCon(self,node):
     83        return node.obj.execute(self,node)
    8184#}}}
    8285
  • container/qg_rewrite.py

    r56 r57  
    136136                    tatype = ta.type
    137137                    s = n.source[ta.source]
     138                    if(not isinstance(s,Node)):
     139                        continue
    138140                    sfield = ta.start_field
    139141                    efield = ta.end_field
     
    228230           if(efield == -1):
    229231               efield = sfield + len(s.fields)
    230            if(s.out_trans is None):
    231                s.out_trans = numpy.zeros((len(s.fields),),dtype=bool)
    232 
    233            if(tatype & itypes_py.TA_FLEX):
    234                s.out_trans |= node.out_trans[sfield:efield]
    235                if(tatype & itypes_py.TA_LEASTONE):
    236                     s.tmp = True #tmp variable indicating this node should at least
    237                                  #return one field
    238            elif(tatype & itypes_py.TA_REDUCE):
    239                nactidx = s.actidx[node.out_trans]
    240                s.out_trans[nactidx] = True
    241                node.in_trans[ta.source] = nactidx
    242            elif(tatype & itypes_py.TA_FIX):
    243                s.out_trans[s.actidx] = True
    244                node.in_trans[ta.source] = s.actidx
    245                node.out_trans[sfield:efield] = True
    246            elif(tatype & itypes_py.TA_PT):
    247                s.out_trans |= node.out_trans[sfield:efield]
    248                node.out_trans[sfield:efield] = False
    249            else:
    250                raise RuntimeError, "Unexpected table attribute type"
     232           if(isinstance(s,Node)):
     233              if(s.out_trans is None):
     234                  s.out_trans = numpy.zeros((len(s.fields),),dtype=bool)
     235
     236              if(tatype & itypes_py.TA_FLEX):
     237                  s.out_trans |= node.out_trans[sfield:efield]
     238                  if(tatype & itypes_py.TA_LEASTONE):
     239                          s.tmp = True #tmp variable indicating this node should at least
     240                                      #return one field
     241              elif(tatype & itypes_py.TA_REDUCE):
     242                  nactidx = s.actidx[node.out_trans]
     243                  s.out_trans[nactidx] = True
     244                  node.in_trans[ta.source] = nactidx
     245              elif(tatype & itypes_py.TA_FIX):
     246                  s.out_trans[s.actidx] = True
     247                  node.in_trans[ta.source] = s.actidx
     248                  node.out_trans[sfield:efield] = True
     249              elif(tatype & itypes_py.TA_PT):
     250                  s.out_trans |= node.out_trans[sfield:efield]
     251                  node.out_trans[sfield:efield] = False
     252              else:
     253                  raise RuntimeError, "Unexpected table attribute type"
    251254        return True
    252255
  • container/qgraph_py.py

    r56 r57  
    66
    77#special root class, used as root for query graph
    8 class Root(opcon.UnaryOpCon):
     8class ExtendCon(opcon.OpCon):
     9    pass
     10
     11class Root(opcon.OpCon):
    912    pass
    1013
  • container/srccon.py

    r56 r57  
    1111import xnumpy
    1212
    13 
     13import capcon
    1414class SrcCon(container.Container):
    1515    """SrcCon: main class for source containers"""
     
    2424        self._fields = fields
    2525        self._actidx = numpy.arange(len(fields)).view(xnumpy.XArray)
    26         self._props = itypes_py.createProps(pname=  (tablename    ,""),
    27                                   ptype=  ("tablename"  ,"conn"),
     26        self._props = itypes_py.createProps(pname=  ("tablename","conn"),
     27                                  ptype=  (capcon.Property  ,capcon.Property),
    2828                                  pfields = None,
    29                                   pvalue = (None,conn)
     29                                  pvalue = (tablename,conn)
    3030                                  )
    3131        self._invar = container.Invar()
Note: See TracChangeset for help on using the changeset viewer.