fded6ce46cb167faaae559ff93b050c2b7d18ff1 max Mon Jun 26 08:59:00 2023 -0700 Porting hgGeneGraph to python3. refs #31563 diff --git src/hg/pyLib/pymysql/protocol.py src/hg/pyLib/pymysql/protocol.py new file mode 100644 index 0000000..41c8167 --- /dev/null +++ src/hg/pyLib/pymysql/protocol.py @@ -0,0 +1,358 @@ +# Python implementation of low level MySQL client-server protocol +# http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +from .charset import MBLENGTH +from .constants import FIELD_TYPE, SERVER_STATUS +from . import err + +import struct +import sys + + +DEBUG = False + +NULL_COLUMN = 251 +UNSIGNED_CHAR_COLUMN = 251 +UNSIGNED_SHORT_COLUMN = 252 +UNSIGNED_INT24_COLUMN = 253 +UNSIGNED_INT64_COLUMN = 254 + + +def dump_packet(data): # pragma: no cover + def printable(data): + if 32 <= data < 127: + return chr(data) + return "." + + try: + print("packet length:", len(data)) + for i in range(1, 7): + f = sys._getframe(i) + print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno)) + print("-" * 66) + except ValueError: + pass + dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)] + for d in dump_data: + print( + " ".join("{:02X}".format(x) for x in d) + + " " * (16 - len(d)) + + " " * 2 + + "".join(printable(x) for x in d) + ) + print("-" * 66) + print() + + +class MysqlPacket: + """Representation of a MySQL response packet. + + Provides an interface for reading/parsing the packet results. + """ + + __slots__ = ("_position", "_data") + + def __init__(self, data, encoding): + self._position = 0 + self._data = data + + def get_all_data(self): + return self._data + + def read(self, size): + """Read the first 'size' bytes in packet and advance cursor past them.""" + result = self._data[self._position : (self._position + size)] + if len(result) != size: + error = ( + "Result length not requested length:\n" + "Expected=%s. Actual=%s. Position: %s. Data Length: %s" + % (size, len(result), self._position, len(self._data)) + ) + if DEBUG: + print(error) + self.dump() + raise AssertionError(error) + self._position += size + return result + + def read_all(self): + """Read all remaining data in the packet. + + (Subsequent read() will return errors.) + """ + result = self._data[self._position :] + self._position = None # ensure no subsequent read() + return result + + def advance(self, length): + """Advance the cursor in data buffer 'length' bytes.""" + new_position = self._position + length + if new_position < 0 or new_position > len(self._data): + raise Exception( + "Invalid advance amount (%s) for cursor. " + "Position=%s" % (length, new_position) + ) + self._position = new_position + + def rewind(self, position=0): + """Set the position of the data buffer cursor to 'position'.""" + if position < 0 or position > len(self._data): + raise Exception("Invalid position to rewind cursor to: %s." % position) + self._position = position + + def get_bytes(self, position, length=1): + """Get 'length' bytes starting at 'position'. + + Position is start of payload (first four packet header bytes are not + included) starting at index '0'. + + No error checking is done. If requesting outside end of buffer + an empty string (or string shorter than 'length') may be returned! + """ + return self._data[position : (position + length)] + + def read_uint8(self): + result = self._data[self._position] + self._position += 1 + return result + + def read_uint16(self): + result = struct.unpack_from("= 7 + + def is_eof_packet(self): + # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet + # Caution: \xFE may be LengthEncodedInteger. + # If \xFE is LengthEncodedInteger header, 8bytes followed. + return self._data[0] == 0xFE and len(self._data) < 9 + + def is_auth_switch_request(self): + # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest + return self._data[0] == 0xFE + + def is_extra_auth_data(self): + # https://dev.mysql.com/doc/internals/en/successful-authentication.html + return self._data[0] == 1 + + def is_resultset_packet(self): + field_count = self._data[0] + return 1 <= field_count <= 250 + + def is_load_local_packet(self): + return self._data[0] == 0xFB + + def is_error_packet(self): + return self._data[0] == 0xFF + + def check_error(self): + if self.is_error_packet(): + self.raise_for_error() + + def raise_for_error(self): + self.rewind() + self.advance(1) # field_count == error (we already know that) + errno = self.read_uint16() + if DEBUG: + print("errno =", errno) + err.raise_mysql_exception(self._data) + + def dump(self): + dump_packet(self._data) + + +class FieldDescriptorPacket(MysqlPacket): + """A MysqlPacket that represents a specific column's metadata in the result. + + Parsing is automatically done and the results are exported via public + attributes on the class such as: db, table_name, name, length, type_code. + """ + + def __init__(self, data, encoding): + MysqlPacket.__init__(self, data, encoding) + self._parse_field_descriptor(encoding) + + def _parse_field_descriptor(self, encoding): + """Parse the 'Field Descriptor' (Metadata) packet. + + This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). + """ + self.catalog = self.read_length_coded_string() + self.db = self.read_length_coded_string() + self.table_name = self.read_length_coded_string().decode(encoding) + self.org_table = self.read_length_coded_string().decode(encoding) + self.name = self.read_length_coded_string().decode(encoding) + self.org_name = self.read_length_coded_string().decode(encoding) + ( + self.charsetnr, + self.length, + self.type_code, + self.flags, + self.scale, + ) = self.read_struct("