使用自己的Python函数处理Protobuf中的字符串编码

我目前所在的项目是一个老项目,里面的字符串编码有点乱,数据库中有些是GB2312,有些是UTF8;代码中有些是GBK,有些是UTF8,代码中转来转去,经常是不太清楚当前这个字符串是什么编码,由于是老项目,也没去修改。最近合服脚本由项目上进行维护了,我拿到脚本看了看是Python写的,我之前也没学习过Python,只有现学现用。

数据库中使用了Protobuf,这里面也有字符串,编码也是有GBK,也有UTF8编码的,而且是交叉使用尴尬,有过合服经验的同学应该知道,这里会涉及一些修改,比如名字冲突需要改名。Protobuf中的名字修改就需要先解析出来修改了再序列化回去。这个时候问题来了,Protobuf默认是使用的UTF8编码进行解析(Decode)与序列化的(Encode),可以参见:google.protobuf.internal中的decoder.py中的函数:

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field."""

  local_DecodeVarint = _DecodeVarint
  local_unicode = unicode

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
      return new_pos
    return DecodeField

以及encoder.py中的函数

def StringEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a string field."""

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value):
      for element in value:
        encoded = element.encode('utf-8')
        write(tag)
        local_EncodeVarint(write, local_len(encoded))
        write(encoded)
    return EncodeRepeatedField
  else:
    def EncodeField(write, value):
      encoded = value.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(encoded))
      return write(encoded)
    return EncodeField

如果Protobuf中的字符串编码为非UTF8编码,则在解析(Decode)的过程中会出现异常(有点奇怪的是我同事的电脑上没出现异常):

'utf8' codec can't decode byte……

我们有没有一个方法在不改变Protobuf原来的代码的情况下使用自己的函数来进行解析呢,这是我首先想到的,由于没学习过Python,恶补了一下Python基础后,研究发现Protobuf是把Decode的函数入口放在了一个数组中,在引入模块的时候就会自动初始化这些入口函数,然后保存到各个Protobuf类中,各个PB类都有一个decoders_by_tag字典,这个字典就存放了各种数据类型的解析函数入口地址。

通过上面的代码可以看出,具体解析函数(DecodeField)是放在一个闭包中的,不能直接修改,所以必须整个(StringDecoder)替换。通过深入研究,终于发现了其设置的入口,在google.protobuf.internal的type_checkers.py中有这样一段代码:

# Maps from field types to encoder constructors.
TYPE_TO_ENCODER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
    _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
    }


# Maps from field types to sizer constructors.
TYPE_TO_SIZER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
    _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
    }


# Maps from field type to a decoder constructor.
TYPE_TO_DECODER = {
    _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
    _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
    _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
    _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
    _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
    _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
    _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
    _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
    _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
    _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
    _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
    _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
    _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
    _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
    _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
    _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
    _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
    _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
    }

第一个是序列化(Encoder)的函数入口,第二个是计算大小的函数入口,第三个就是解析(Decoder)的入口,我们可以看到这里映射了所有类型的处理函数入口,那我们把这个入口函数替换成我们自己的函数,就可以根据实际需要进行处理了。

这里我们需要特别注意的是Protobuf中的各个类都是在模块导入的时候就初始化好了,所以,如果我们要修改入口函数,必须在PB各类引入之前进行修改。为此我写了一个模块文件:protobuf_hack.py,这个模块必须先于PB类import,其内容如下:

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf.internal import type_checkers
from google.protobuf import reflection
from google.protobuf import message

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field."""

  local_DecodeVarint = _DecodeVarint
  local_unicode = unicode

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(local_unicode(buffer[pos:new_pos], 'gbk'))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = local_unicode(buffer[pos:new_pos], 'gbk')
      return new_pos
    return DecodeField

type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder 


这样,我们可以把所有PB中的字符串解析按GBK编码解析了。但是项目中的字符串并不是所有的字符串都是GBK编码的,也有UTF8编码的,为了支持两种编码,我做了一个处理,就是先尝试使用一种编码解析,如果出现异常,再使用另一种编码进行解析,这样就保证了我们所有的字符串都可以正确解析。理想很丰满,现实很骨感,解析是正确了,但是如果我们序列化回去在服务器程序中去使用的时候就会出现乱码,因为原来的GBK或者UTF8统一成UTF8编码了,当然,我们也可以继续像Decoder调用自己的函数一样处理Encoder,但是在Encoder中我们并不知道这个字符串原来在数据库中是什么编码,也没有PB以及字段信息,无法差别处理。

至此,算是白忙活了,无法满足需要。

如果我们能够只修改我们指定的PB类的处理函数就好了,因为我们可以找出哪些PB的字符串是GBK编码的。再次经过深入研究,总算是做到了。

在这里有一个函数帮了我大忙,reflection.py中的ParseMessage函数,我们看一下:

def ParseMessage(descriptor, byte_str):
  """Generate a new Message instance from this Descriptor and a byte string.

  Args:
    descriptor: Protobuf Descriptor object
    byte_str: Serialized protocol buffer byte string

  Returns:
    Newly created protobuf Message object.
  """

  class _ResultClass(message.Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  new_msg.ParseFromString(byte_str)
  return new_msg


这个函数其实就是通过描述符信息(descriptor)来解析二进制串,生成一个新的PB消息实例。这中间的关键就是函数中的那个动态生成类实例的代码,在这里会走一次PB类的初始化流程,即会初始化我们所需要的Decoder以及Encoder函数映射字典。为了工作需要,我修改一下这个函数:

def ParseMessage(descriptor):
  class _ResultClass(message.Message):
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  return new_msg

然后加入我们需要使用自定义函数处理的PB类,注意这里一定是所需要的最小的PB结构。
def hacker(msg):
    ParseMessage(msg.DESCRIPTOR)
	
def hack_pb():
    #修改默认的字符串处理函数入口为自定义函数
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = StringSizer

    try:
        # 这里加入我们需要修改的PB类
        hacker(DbProto.DB_FriendAssetEntry_PB)
    except Exception as e:
        print(e)

    #还原字符串处理函数入口
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = decoder.StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringSizer

由于Encode的时候Protobuf是先计算字段的长度,然后再处理的各字段,所以我们还需要把计算大小的函数使用自定义函数,否则再次解析会出问题。

现在基本上满足了需要,算是大功告成了!

细心的读者,不知你发现没,这里还是有一个问题,目前无法解决的问题,就是如果我们一个最小的PB中如果有两个字符串字段,采用的不同的编码怎么办?一般情况下,正常的设计者不会这样做,但是就像我们项目中的编码混乱一样,如果一个不小心就搞成不一样的编码就悲剧了!如果哪位高手有此解决方案,欢迎分享!!!

把整个文件附上:

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf.internal import type_checkers
from google.protobuf import reflection
from google.protobuf import message

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a string field."""

    local_DecodeVarint = decoder._DecodeVarint
    local_unicode = unicode

    assert not is_packed
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while 1:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise decoder._DecodeError('Truncated string.')
                str = '' #这里先尝试使用UTF8编码进行解析,如果出现异常则尝试使用GBK编码解析
                try:
                    str = local_unicode(buffer[pos:new_pos], 'utf-8')
                except Exception as e:
                    try:
                        str = local_unicode(buffer[pos:new_pos], 'gbk')
                    except Exception as e1:
                        str = ''

                value.append(str)
                # Predict that the next tag is another copy of the same repeated field.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos

        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise decoder._DecodeError('Truncated string.')

            str = '' #这里先尝试使用UTF8编码进行解析,如果出现异常则尝试使用GBK编码解析
            try:
                str = local_unicode(buffer[pos:new_pos], 'utf-8')
            except Exception as e:
                try:
                    str = local_unicode(buffer[pos:new_pos], 'gbk')
                except Exception as e1:
                    str = ''

            field_dict[key] = str
            return new_pos

        return DecodeField


def StringEncoder(field_number, is_repeated, is_packed):
    """Returns an encoder for a string field."""

    tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = encoder._EncodeVarint
    local_len = len
    assert not is_packed
    if is_repeated:
        def EncodeRepeatedField(write, value):
            for element in value:
                encoded = element.encode('gbk') #序列化的时候就直接使用GBK编码了
                write(tag)
                local_EncodeVarint(write, local_len(encoded))
                write(encoded)

        return EncodeRepeatedField
    else:
        def EncodeField(write, value):
            encoded = value.encode('gbk') #序列化的时候就直接使用GBK编码了
            write(tag)
            local_EncodeVarint(write, local_len(encoded))
            return write(encoded)

        return EncodeField

def StringSizer(field_number, is_repeated, is_packed):
    """Returns a sizer for a string field."""

    tag_size = encoder._TagSize(field_number)
    local_VarintSize = encoder._VarintSize
    local_len = len
    assert not is_packed
    if is_repeated:
        def RepeatedFieldSize(value):
            result = tag_size * len(value)
            for element in value:
                l = local_len(element.encode('gbk')) #注意序列化前计算长度时也需要使用与序列化相同的编码,否则会出错
                result += local_VarintSize(l) + l
            return result

        return RepeatedFieldSize
    else:
        def FieldSize(value):
            l = local_len(value.encode('gbk')) #注意序列化前计算长度时也需要使用与序列化相同的编码,否则会出错
            return tag_size + local_VarintSize(l) + l

        return FieldSize

def ParseMessage(descriptor):
  class _ResultClass(message.Message):
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  return new_msg

def hacker(msg):
    ParseMessage(msg.DESCRIPTOR)

def hack_pb():
    # 修改默认的字符串处理函数入口为自定义函数
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = StringSizer

    try:
        # 这里加入我们需要修改的PB类,注意这里需要自行import DbProto模块
        hacker(DbProto.DB_FriendAssetEntry_PB)
    except Exception as e:
        print(e)

    # 还原字符串处理函数入口
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = decoder.StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringSizer

#这里让其在引入模块时自动执行
hack_pb()


posted @ 2016-11-04 19:03  witton  阅读(2348)  评论(0编辑  收藏  举报