get_cousin_bfs

# created by bohuai jiang
# on 2019/7/23
# last modified on 2019/9/17 10:14
# -*- coding: utf-8 -*-

from sqlparse.sql import Statement, Comment, Where, Identifier, IdentifierList, Parenthesis, Function, \
    Comparison, Operation, Token, TokenList, Values
from sqlparse import tokens as T
from typing import Union, List, Tuple, Optional, Set


class ParseUnit:
    def __init__(self):
        self.id = None
        self._name = None  # sql code name
        self._as_name = None  # as what name
        self._from_name = None  # from where
        self._type = None  # TAB-table , COL-column, SUB-subquery ,OPT- >,<,=.., FUNC-MAX,SUM..
        self._keyword = None
        self._in_statement = 'OTHER'

        # self._opt = None
        self._parent = set()
        self._edges = set()
        self._level = 0
        self.token = None

    @property
    def in_statement(self) -> str:
        return self._in_statement

    @property
    def level(self) -> int:
        return self._level

    @property
    def keyword(self) -> str:
        return self._keyword

    @property
    def name(self) -> str:
        return self._name

    @property
    def as_name(self) -> str:
        return self._as_name

    @property
    def from_name(self) -> str:
        return self._from_name

    @property
    def parent(self) -> set:
        return self._parent

    @property
    def type(self) -> str:
        return self._type

    @property
    def edges(self) -> set:
        return self._edges

    # @property
    # def opt(self) -> str:
    #     return self._opt
    @keyword.setter
    def keyword(selfkeystr):
        self._keyword = key.upper()

    @level.setter
    def level(selflevelint):
        self._level = level

    @name.setter
    def name(selfname: Optional[str]):
        if type(name) == str:
            self._name = name.upper()
        else:
            self._name = name

    @as_name.setter
    def as_name(selfas_namestr):
        self._as_name = as_name.upper()

    @from_name.setter
    def from_name(selffrom_namestr):
        self._from_name = from_name.upper()

    @parent.setter
    def parent(selfparent: Set['ParseUnit']):
        self._parent = parent

    # @opt.setter
    # def opt(self, opt: str):
    #     self._opt = opt

    @type.setter
    def type(selftypestr):
        if type not in ['COL''TAB''SUB''OPT''FUNC''STRUCT''VALUE']:
            raise ValueError('type must be either one of following [COL, TAB, SUB, OPT, FUNC, STRUC, VALUE]')
        self._type = type.upper()

    @in_statement.setter
    def in_statement(selfstatestr):
        if state not in ['WHERE''ORTHER']:
            raise ValueError('type must be either one of following [WHERE, OTHER]')
        self._in_statement = state

    @edges.setter
    def edges(selfedges: Set['ParseUnit']):
        self._edges = edges

    def overwrite(selfunit'ParseUnit'):
        if unit.name is not None:
            self._name = unit.name
        if unit.as_name is not None:
            self._as_name = unit.as_name
        if unit.from_name is not None:
            self._from_name = unit.from_name
        if unit.parent is not None:
            self._parent = unit.parent
        if unit.type is not None:
            self._type = unit.type
        if not unit.edges:
            self._edges = unit.edges

    def inherit(selfunit'ParseUnit'update_edgesbool = False):
        self._name = unit.name
        self._as_name = unit.as_name
        if unit.from_name != 'DUMMY':
            self._from_name = unit.from_name
        self._type = unit.type
        if update_edges:
            self._edges.add(unit.id)

    def show(self) -> str:
        out = ''
        if self._from_name is not 'DUMMY' and not None:
            out += self._from_name + '.'
        out += self._name
        if self._as_name is not 'DUMMY' and not None:
            out += ' as ' + self._as_name
        return out

    def add_parents(selfparents: Union[List[int], Set[int]]) -> None:
        for p in parents:
            self._parent.add(p)

    def __repr__(self):
        out = '%s\n' % str(self.id)
        out += '\ttype:%s\n' % self.type
        out += '\tname:%s\n' % self.name
        out += '\tkeyword:%s\n' % self.keyword
        out += '\tstatement:%s\n' % self.in_statement
        out += '\tlevel:%s\n' % self.level
        out += '\tas_name:%s\n' % self.as_name
        out += '\tfrom:%s\n' % self.from_name
        out += '\tparent:%s\n' % str(self.parent)
        out += '\tedges:%s\n' % str(self.edges)

        return out

    def get_name(self)-> str:
        if self.as_name is not None or self.as_name.upper() != 'DUMMY':
            return self.as_name
        else:
            return self.name

class ParseUnitList:
    def __init__(self) -> None:
        # -- tab col relation -- #
        self.by_type = {'COL': [],
                        'TAB': [],
                        'SUB': [],
                        'OPT': [],
                        'FUNC': [],
                        'STRUCT': [],
                        'VALUE': []}
        self.by_id = dict()  # G
        self._allow_sub_has_table = False

    def __insert(selfunit: ParseUnit) -> int:
        # o(mn) m<n
        id = len(self.by_id)
        unit.id = id
        # for i, each_units in enumerate(self.by_type[unit.type]):
        #     as_name = each_units.as_name
        #     if unit.name == as_name and (unit.from_name == each_units.from_name \
        #                                  or each_units.from_name == 'DUMMY'):
        #         unit.inherit(unit=each_units, update_edges=True)
        #         each_units.inherit(unit=unit)
        #         self.by_id[each_units.id] = each_units
        #         break
        # -----#
        self.by_type[unit.type].append(unit)
        self.by_id[unit.id] = unit
        return id

    def __update_by_type(self) -> None:
        for key in ['SUB''TAB''OPT''FUNC''COL''STRUCT']:
            for unit in self.by_type[key]:
                self.by_id[unit.id] = unit

    def __update_by_id(self):
        self.by_type = {'COL': [],
                        'TAB': [],
                        'SUB': [],
                        'OPT': [],
                        'FUNC': [],
                        'STRUCT': [],
                        'VALUE': []}
        for id in self.by_id:
            unit = self.by_id[id]
            self.by_type[unit.type].append(unit)

    ########################################
    #           add  function              #
    ########################################

    # ----------- add by token type -----------#

    def _add_Identifier(selftokens: Token, typestrkeystrlevelintis_wherebool,
                        parents: List[int= None) -> Tuple[int, Union[Token, TokenList]]:
        out = ParseUnit()
        if '(' in tokens.value and tokens.value != '(':
            out.type = 'SUB'
        else:
            out.type = type
        out.keyword = key
        out.level = level
        dot_flag = 1
        out.token = tokens
        if is_where:
            out.in_statement = 'WHERE'
        if parents is not None and parents != []:
            out.add_parents(parents)

        abnormal = None
        try:
            for t in tokens:
                if str(t.ttype).upper() == 'TOKEN.PUNCTUATION' and t.value == '.':
                    dot_flag += 1
                    continue
                if str(t.ttype).upper() == 'TOKEN.NAME':
                    if dot_flag % 2 == 0:
                        out.name = t.value
                        dot_flag += 1
                    else:
                        out.from_name = t.value
                if t.ttype is None:
                    out.as_name = t.value
                    if not isinstance(t, Identifier):
                        abnormal = t
            if dot_flag <= 1:
                out.name = out.from_name
                out.from_name = 'DUMMY'
        except:
            out.name = tokens.value
        # --- double check whether used  dot --- #

        if out.as_name is None:
            out.as_name = 'DUMMY'

        # -- patch --#
        if out.name is None:
            if abnormal is not None:
                out.name = abnormal.value
            else:
                out.name = out.as_name
        # -- add  order by or group by -- #
        keyList = ['ORDER BY''GROUP BY']

        if key in keyList:
            # -- find nearest opt -- #
            for id in range(len(self.by_id))[::-1]:
                acquire_id = id
                unit = self.by_id[id]
                if unit.type == 'OPT' and unit.name == key:
                    break
            out.parent.add(acquire_id)

        # -- add to like -- #
        if key == 'LIKE':
            out.add_parents([len(self.by_id) - 1])
        id = self.__insert(out)
        return id, abnormal

    def _add_Comparison(selftokens: Token, typestrkeystrlevelintis_whereboolparents: List[int= None) \
            -> Optional[List[dict]]:
        if isinstance(tokens, Comparison):
            # -- get opt unit --#
            opt = None
            for t in tokens:
                if t.ttype == T.Operator.Comparison:
                    opt = t.value
            unit = ParseUnit()
            unit.name = opt
            unit.type = 'OPT'
            unit.keyword = key
            unit.level = level
            count = 0
            for t in tokens:
                if not t.is_whitespace:
                    count += 1
                if count == 2:
                    unit.token = t
                    break
            expect_id = len(self.by_id) + 1
            if is_where:
                unit.in_statement = 'WHERE'
            if parents is not None and parents != []:
                unit.add_parents(parents)

            # -- left unit -- #
            parents = [expect_id]
            parents_token_left = self.add(tokens=tokens.left, type=typekey=key, level=level, parents=parents,
                                          is_where=is_where)

            self.__insert(unit)

            parents_token_right = self.add(tokens=tokens.right, type=typekey=key, level=level, parents=parents,
                                           is_where=is_where)

            # unit.edges.add(left_v)
            # unit.edges.add(right_v)

            if parents_token_left and parents_token_right:
                return parents_token_left + parents_token_right
            elif parents_token_left:
                return parents_token_left
            else:
                return parents_token_right
        else:
            unit = ParseUnit()
            unit.name = tokens.value
            unit.type = 'OPT'
            unit.keyword = key
            unit.level = level
            unit.token = tokens
            current_id = len(self.by_id)
            self.__insert(unit)
            if is_where:
                unit.in_statement = 'WHERE'
            if parents is not None and parents != []:
                unit.add_parents(parents)
            unit.edges = {current_id - 1, current_id + 1}
            return None

    def _add_Operation(selftokens: Operation, typestrkeystrlevelintis_wherebool,
                       parents: List[int= None):
        unit = ParseUnit()
        unit.name = tokens.value
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        expect_id = len(self.by_id)
        self.__insert(unit)
        for t in tokens.tokens:
            self.add(tokens=t, type=typekey=key, level=level, parents=[expect_id],
                     is_where=is_where)

    def _add_Function(selftokens: Function, keystrlevelintis_whereboolparents: List[int= None) \
            -> Tuple[int, Optional[list]]:
        unit = ParseUnit()
        unit.name = tokens.tokens[0].value
        unit.type = 'FUNC'
        unit.keyword = key
        unit.level = level
        unit.token = tokens.tokens[0]
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        id = self.__insert(unit)
        return id, tokens.tokens[1::]

    def _add_Parenthesis(selftokens: Parenthesis, keystrlevelintis_whereboolparents: List[int= None) \
            -> Tuple[int, Parenthesis]:
        unit = ParseUnit()
        unit.name = tokens.value
        unit.type = 'SUB'
        unit.keyword = key
        unit.level = level
        unit.from_name = 'DUMMY'
        unit.as_name = 'DUMMY'
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        id = self.__insert(unit)
        return id, tokens

    def add(selftokens: Union[Token, TokenList], typestris_whereboolkeystrlevelint, \
            parents: List[int= None) -> Optional[List[dict]]:

        if isinstance(tokens, Identifier):
            id, abnormal = self._add_Identifier(tokens=tokens, type=typeparents=parents, key=key, level=level,
                                                is_where=is_where)
            if abnormal is not None:
                if isinstance(abnormal, Function):
                    return self.add(tokens=abnormal, type=typeparents=[id], key=key, level=level,
                                    is_where=is_where)
                else:
                    return [{'parents': [id], 'tokens': [abnormal]}]
            else:
                return None
        elif isinstance(tokens, Comparison):
            abnormal = self._add_Comparison(tokens=tokens, type=typeparents=parents, key=key, is_where=is_where,
                                            level=level)
            return abnormal
        elif isinstance(tokens, Function):
            id, token_list = self._add_Function(tokens=tokens, parents=parents, key=key, level=level,
                                                is_where=is_where)
            return [{'parents': [id], 'tokens': token_list}]
        elif isinstance(tokens, Parenthesis):
            id, token = self._add_Parenthesis(tokens=tokens, parents=parents, key=key, level=level,
                                              is_where=is_where)
            return [{'parents': [id], 'tokens': [token]}]
        elif isinstance(tokens, Values):
            rest = self._add_value(tokens=tokens, level=level, parents=parents, is_where=is_where)
            return rest
        elif isinstance(tokens, Operation):
            self._add_Operation(tokens=tokens, type=typeparents=parents, key=key, is_where=is_where, level=level)
        elif tokens.value.upper() == 'IN':
            self._add_In(tokens=tokens, is_where=is_where, key=key, level=level, parents=parents)
        else:
            # capture missed comparetor:
            if tokens.ttype == T.Operator.Comparison:
                self._add_Comparison(tokens=tokens, type=typeparents=parents, key=key, is_where=is_where,
                                     level=level)
                return None
            type = 'STRUCT' if str(tokens.ttype[0]) not in ['Literal''Number'else 'VALUE'
            id, token_list = self._add_Identifier(tokens=tokens, type=typeparents=parents, key=key, level=level,
                                                  is_where=is_where)
            if token_list is not None:
                self.add(tokens=token_list, type=typeis_where=is_where, key=key, level=level, parents=[id])

        return None

    # ----------- add by keywords ----------- #
    def _add_In(selftokens: Token, keystrlevelintis_whereboolparents: List[int= None) -> None:
        # acquire id
        cur_id = len(self.by_id)
        left_id = cur_id - 1
        right_id = cur_id + 1
        # --build in Node -#
        unit = ParseUnit()
        unit.name = 'IN'
        unit.type = 'OPT'
        unit.edges = {left_id, right_id}
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        self.__insert(unit)
        left = self.by_id[left_id]
        left.parent.add(cur_id)

    def add_order(selftokens: Token, keystrlevelintis_whereboolparents: List[int= None) -> Optional[
        List[dict]]:
        next_id = len(self.by_id) + 1
        unit = ParseUnit()
        unit.name = tokens.value
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        unit.edges.add(next_id)
        self.__insert(unit)
        return None

    def add_like(selftokens: Token, keystrlevelintis_whereboolparents: List[int= None) -> Optional[
        List[dict]]:
        pre_id = len(self.by_id) - 1
        unit = ParseUnit()
        unit.name = tokens.value
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        unit.edges.add(pre_id)
        self.__insert(unit)
        return None

    def add_between(selftokens: Token, keystrlevelintis_whereboolparents: List[int= None) -> Optional[
        List[dict]]:
        id_pre = len(self.by_id) - 1
        unit = ParseUnit()
        unit.name = 'BETWEEN'
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        unit.edges.add(id_pre)
        # unit.edges.add(id_n_left)
        # unit.edges.add(id_n_right)
        self.__insert(unit)
        return None

    def _add_value(selftokens: Token, levelintis_whereboolparents: List[int= None) -> Optional[List[dict]]:
        self._allow_sub_has_table = True
        col_id = len(self.by_id) - 1
        unit = ParseUnit()
        unit.name = 'VALUES'
        unit.type = 'OPT'
        unit.keyword = 'VALUES'
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        unit.edges = {col_id, col_id + 2}
        self.__insert(unit)

        out = []
        for t in tokens.tokens[1::]:
            if isinstance(t, Parenthesis):
                p, tokens = self._add_Parenthesis(tokens=t, key='VALUES'level=level, parents=[col_id + 1],
                                                  is_where=is_where)
                id = self.__insert(p)
                out.append({'parents': [id], 'tokens': [tokens]})
        return out

    def add_is(selftokens: Token, keystrlevelintis_whereboolparents: List[int= None) -> Optional[
        List[dict]]:
        pre_id = len(self.by_id) - 1
        unit = ParseUnit()
        unit.name = tokens.value
        unit.type = 'OPT'
        unit.keyword = key
        unit.level = level
        unit.token = tokens
        if is_where:
            unit.in_statement = 'WHERE'
        if parents is not None and parents != []:
            unit.add_parents(parents)
        unit.edges.add(pre_id)
        self.__insert(unit)
        return None

    #########################################

    def __iter__(self):
        return iter(self.by_id.values())

    #########################################
    #          build relation function      #
    #########################################

    def build_relation(self):
        # --- build parents ---#
        symbol_idx = dict()  # {as_name/id: [index]}
        idx_edges = dict()  # {id : [index]}
        for key in self.by_id.keys():
            idx_edges[key] = set()

        check_keys = ['COL'if not self._allow_sub_has_table else ['COL''SUB']

        # -- buil tab col relation --#
        for key in ['SUB''TAB''COL']:
            for unit in self.by_type[key]:
                key_i = unit.type
                # -- add edges -- #
                if len(unit.parent) > 0:
                    for p in unit.parent:
                        idx_edges[p].add(unit.id)

                if key_i == 'TAB':
                    symbol = unit.as_name
                    if symbol not in symbol_idx.keys():
                        symbol_idx[symbol] = [unit.id]
                    else:
                        symbol_idx[symbol].append(unit.id)

                    if unit.name not in symbol_idx.keys():
                        symbol_idx[unit.name] = [unit.id]
                    else:
                        symbol_idx[unit.name].append(unit.id)
                # -- update parents --#
                if key_i in check_keys:
                    if unit.from_name != 'DUMMY':
                        try:
                            parent_indexes = symbol_idx[unit.from_name]
                        except:
                            parent_indexes = []
                            # raise SQLGrammarError('invalid column: ' + unit.name)
                        for parent in parent_indexes:
                            unit.parent.add(parent)
                            idx_edges[parent].add(unit.id)
                    else:
                        all_parents = self.add_all_parents(unit.level)
                        if len(all_parents) == 1:
                            parent = self.by_id[all_parents.pop()]
                            as_name = parent.as_name if parent.as_name != 'DUMMY' else parent.name
                            unit.from_name = as_name
                        unit.add_parents(self.add_all_parents(unit.level))
                        for p in unit.parent:
                            idx_edges[p].add(unit.id)
        self.__update_by_type()
        # --- build parents ---#
        between_count = None
        blevel = None
        b_id = None
        for id in self.by_id.keys():
            unit = self.by_id[id]
            edges = unit.edges
            # -- between handler --#
            if between_count is not None:
                if blevel == unit.level:
                    between_count += 1
                    unit.parent.add(b_id)
                if blevel == 3:
                    between_count = None
                    blevel = None
                    b_id = None
            if unit.name == 'BETWEEN':
                between_count = 0
                blevel = unit.level
                b_id = unit.id

            for ed in edges:
                try:
                    self.by_id[ed].parent.add(id)
                except:
                    continue
        # --- build edges --- #
        for id in self.by_id:
            parents = self.by_id[id].parent
            for pa in parents:
                self.by_id[pa].edges.add(id)
        self._allow_sub_has_table = False
        self.__update_by_id()

    def add_all_parents(selflevelint) -> Set[int]:
        parents = set()
        for key in ['TAB']:
            for unit in self.by_type[key]:
                if unit.level == level:
                    parents.add(unit.id)
        return parents

    ############################################
    #              graph search                #
    ############################################

    def get_parent(self,graph: ParseUnit,level:int = 1) -> List[int]:
        """获取前几代的家长,采用了dfs遍历所有家长"""
        self.visited=set()
        self.result=[]
        self.get_parent_helper(graph,level)
        return self.result

    def get_parent_helper(self,graph: ParseUnit,level:int = 1) -> List[int]:
        if graph.id not in self.visited:
            if level==0:
                self.result.append(graph.id)
                self.visited.add(graph.id)
                return self.result
            parents = self.by_id[graph.id].parent
            for p_id in parents:
                self.get_parent_helper(self.by_id[p_id],level-1)
                self.visited.add(p_id)

    def get_cousin_bfs(self,graph: ParseUnit,level:int = 1) -> List[int]:
        if level==0:
            return [graph.id]
        visited=set()
        q = self.get_parent(graph,level)
        while q:
            if level==0:
                return list(q)
            tmp_q=set()
            while q:
                u_id = q.pop(0)
                if u_id not in visited:
                    unit = self.by_id[u_id]
                    for child in unit.edges:
                        tmp_q.add(child)
                    visited.add(u_id)
            level-=1
            q=tmp_q
        return []

    def get_cousin(selfgraph: ParseUnit, level:int = 1) -> List[int]:
        result = []
        path = []
        q = [(graph.id,0)]
        if level ==  0:
            return [graph.id]
        while q:
            v, c_level = q.pop(0)
            if c_level == 0:
                result.append(v)
            if not v in path:
                path = path + [v]
                units = self.by_id[v]
                # -- operation -- #
                # -- #
                children = []
                parents = []
                for child in units.edges:
                    children.append((child, c_level-1))
                q = q + list(units.edges)
                pass



    def find_root(selfgraph: ParseUnit, col_onlybool = False) -> Optional[List[int]]:
        root = []
        path = []
        q = [graph.id]
        while q:
            v = q.pop(0)
            if not v in path:
                path = path + [v]
                units = self.by_id[v]
                if col_only:
                    if units.type == 'COL' and '(' not in units.name:
                        root.append(units.id)
                else:
                    if len(units.edges) == 0:
                        root.append(units.id)
                q = q + list(units.edges)
        return root

    def find_tab(selfcolum: ParseUnit, tab_onlybool = False) -> Optional[List[int]]:
        tabs = []
        path = []
        q = [colum.id]
        while q:
            v = q.pop(0)
            if not v in path:
                path = path + [v]
                # --- #
                units = self.by_id[v]
                if tab_only:
                    if units.type == 'TAB':
                        if units.id not in tabs:
                            tabs.append(units.id)
                else:
                    if len(units.parent) == 0:
                        if units.id not in tabs:
                            tabs.append(units.id)
                # ---#
                q = q + list(units.parent)
        return tabs

    ############################
    #        remove node       #
    ############################

    def remove(selfid_list: List[int]):
        all_list = []
        for id in id_list:
            trunk_id = self._get_remove_trunk(self.by_id[id])
            all_list.extend(trunk_id)

        for id in all_list:
            del self.by_id[id]

    def _get_remove_trunk(selfunit: ParseUnit) -> List[int]:
        id_list = []
        target_level = unit.level
        path = []
        q = [unit.id]
        while q:
            v = q.pop(0)
            if not v in path:
                if type(v) == int:
                    path = path + [v]
                    # --- #
                    units_ = self.by_id[v]
                    c_level = units_.level
                    if units_.type != 'TAB' and units_.type != 'SUB':
                        id_list.append(units_.id)
                        q = q + list(units_.parent) + list(units_.edges)
                    else:
                        if c_level > target_level:
                            id_list.append(units_.id)
        return id_list

    ###############################
    #        plot relation        #
    ###############################
    def tab_merge(self):
        pass

    def polt_children(selfunitl: ParseUnit):
        pass

    def show_all_relations(self):
        pass
posted @ 2019-10-17 16:12  applejuice  阅读(98)  评论(0)    收藏  举报