#!/usr/bin/env python
#
# Drizzle Client & Protocol Library
# 
# Copyright (C) 2008 Eric Day (eday@oddments.org)
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#
#     * The names of its contributors may not be used to endorse or
# promote products derived from this software without specific prior
# written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
'''
Drizzle and MySQL Protocol Test Suite

This tool requires anonymous authentication to be open on the MySQL
server.
'''

import sys
import optparse
import socket
import unittest
from prototest.mysql import *

# Command-line options. optparse's automatic -h/--help option is disabled so
# that -h can be used for the host; help is offered through --help only.
parser = optparse.OptionParser(add_help_option=False)

parser.add_option("--help", action="help", help="Print out help details")
parser.add_option("-h", "--host", dest="host", default="localhost",
                  help="Host or IP to test", metavar="HOST")
# Parse the port as an integer up front so that a malformed value is rejected
# with a usage error instead of raising ValueError in every test's setUp().
parser.add_option("-p", "--port", dest="port", type="int", default=3306,
                  help="Port to test", metavar="PORT")
parser.add_option("-t", "--test", dest="testcase", default=None,
                  help="Test case to run", metavar="TESTCASE")

# options is read by the test classes below; positional args are unused.
(options, args) = parser.parse_args()

# Some error codes we'll check for in the tests
BAD_HANDSHAKE = 1043
DB_ACCESS_DENIED = 1044
ACCESS_DENIED = 1045
NO_DATABASE_SELECTED = 1046
UNKNOWN_COMMAND = 1047
BAD_DB = 1049
WRONG_DB_NAME = 1102
PACKET_TOO_LARGE = 1153
PACKETS_OUT_OF_ORDER = 1156

# This is a common range that is used in a number of tests for
# testing various interesting boundaries: every value below 1024, the values
# around 65536 and the values around 1024*1024.
# Each range() is wrapped in list() so that the concatenation does not depend
# on range() itself returning a list.
common_range = (list(range(0, 1024)) +
                list(range(65520, 65550)) +
                list(range(1024*1024-20, 1024*1024+20)))

class TestHandshake(unittest.TestCase):
  '''
  Tests of the connection handshake.

  Every test starts with a freshly connected socket in self.s on which the
  server's handshake packet has not been read yet.
  '''

  def setUp(self):
    '''Connect self.s to options.host:options.port with TCP_NODELAY set.'''
    address = (options.host, int(options.port))
    self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
      self.s.connect(address)
    except socket.error:
      # unittest does not run tearDown() when setUp() raises, so release
      # the socket here instead of leaking it.
      self.s.close()
      raise

  def tearDown(self):
    '''Close the connection opened by setUp().'''
    self.s.close()

  def reconnect(self):
    '''Drop the current connection and open a new one.'''
    self.tearDown()
    self.setUp()

  def testNullClientHandshake(self):
    '''
    Answer the server handshake with a default-constructed Packet and expect
    a PACKETS_OUT_OF_ORDER error in the packet with sequence 1.
    '''
    self.verifyServerHandshake()
    self.s.send(Packet().pack())
    data = self.verifyPacket(1)[1]
    result = ErrorResult(data, version_40=True)
    self.assertEqual(result.error_code, PACKETS_OUT_OF_ORDER)

  def testEmptyRangeClientHandshake(self):
    '''
    For every size in common_range, send a client packet of that size whose
    payload is all zero bytes except for 0x02 in the second byte.

    Sizes below 33 must be answered with BAD_HANDSHAKE. Larger sizes must be
    answered with either an OK result (0 affected rows, 0 insert id) or a
    PACKET_TOO_LARGE error.
    '''
    for x in common_range:
      self.verifyServerHandshake()
      self.s.send(Packet(size=x, sequence=1).pack())

      # Send the MySQL 4.1 protocol flag if possible.
      if x == 1:
        self.s.send('\x00')
      elif x > 1:
        try:
          self.s.send('\x00\x02' + '\x00' * (x - 2))
        except socket.error:
          # A failing send is only accepted for payloads of at least
          # 1024*1024 bytes; the test ends here in that case.
          self.assertTrue(x >= 1024*1024)
          return

      data = self.verifyPacket(2)[1]

      if x < 33:
        result = ErrorResult(data, version_40=True)
        self.assertEqual(result.error_code, BAD_HANDSHAKE)
      else:
        result = create_result(data, version_40=True)
        if isinstance(result, OkResult):
          self.assertEqual(result.affected_rows, 0)
          self.assertEqual(result.insert_id, 0)
        else:
          self.assertTrue(isinstance(result, ErrorResult))
          self.assertEqual(result.error_code, PACKET_TOO_LARGE)

      self.reconnect()

  def testUserOverrun(self):
    '''
    Send the first 32 bytes of a packed client handshake followed by x bytes
    of 0xff, for x below 1024 and around 65536.

    The server must answer with BAD_HANDSHAKE or PACKET_TOO_LARGE.
    '''
    # list() keeps the concatenation valid where range() is not a list.
    for x in list(range(0, 1024)) + list(range(65520, 65550)):
      server_handshake = self.verifyServerHandshake()
      client_handshake = ClientHandshake(capabilities=server_handshake.capabilities.value())

      self.s.send(Packet(size=32+x, sequence=1).pack())
      self.s.send(client_handshake.pack()[:32])
      if x > 0:
        self.s.send('\xff' * x)

      data = self.verifyPacket(2)[1]
      result = ErrorResult(data, version_40=True)
      if result.error_code != BAD_HANDSHAKE:
        self.assertEqual(result.error_code, PACKET_TOO_LARGE)

      self.reconnect()

  def testScrambleOverrun(self):
    '''
    Send the first 33 bytes of a packed client handshake, then one byte with
    value x, then x bytes of 0xff, for x in 0..255.

    The server must answer with a clean OK result, ACCESS_DENIED or
    BAD_HANDSHAKE.
    '''
    for x in range(0, 256):
      server_handshake = self.verifyServerHandshake()
      client_handshake = ClientHandshake(capabilities=server_handshake.capabilities.value())

      self.s.send(Packet(size=34+x, sequence=1).pack())
      self.s.send(client_handshake.pack()[:33])
      # NOTE(review): presumably this byte is the scramble length - confirm
      # against ClientHandshake's layout.
      self.s.send(chr(x))
      if x > 0:
        self.s.send('\xff' * x)

      data = self.verifyPacket(2)[1]
      result = create_result(data)
      if isinstance(result, OkResult):
        self.assertEqual(result.affected_rows, 0)
        self.assertEqual(result.insert_id, 0)
        self.assertEqual(result.warning_count, 0)
      else:
        self.assertTrue(isinstance(result, ErrorResult))
        self.assertTrue(result.error_code == ACCESS_DENIED or
                        result.error_code == BAD_HANDSHAKE)

      self.reconnect()

  def testDBOverrun(self):
    '''
    Send the first 34 bytes of a packed client handshake followed by x
    bytes of 'x', for x in 0..255.

    The server must answer with a clean OK result or an error. For x > 64
    the error must be WRONG_DB_NAME, otherwise DB_ACCESS_DENIED or BAD_DB.
    '''
    for x in range(0, 256):
      server_handshake = self.verifyServerHandshake()
      client_handshake = ClientHandshake(capabilities=server_handshake.capabilities.value())

      self.s.send(Packet(size=34+x, sequence=1).pack())
      self.s.send(client_handshake.pack()[:34])
      if x > 0:
        self.s.send('x' * x)

      data = self.verifyPacket(2)[1]
      result = create_result(data)
      if isinstance(result, OkResult):
        self.assertEqual(result.affected_rows, 0)
        self.assertEqual(result.insert_id, 0)
        self.assertEqual(result.warning_count, 0)
      else:
        self.assertTrue(isinstance(result, ErrorResult))
        if x > 64:
          self.assertEqual(result.error_code, WRONG_DB_NAME)
        else:
          self.assertTrue(result.error_code == DB_ACCESS_DENIED or
                          result.error_code == BAD_DB)

      self.reconnect()

  def testSimple(self):
    '''
    Complete a handshake that echoes the server's capabilities with
    compression switched off, and expect an OK result with no affected
    rows, insert id or warnings.
    '''
    server_handshake = self.verifyServerHandshake()

    client_handshake = ClientHandshake(capabilities=server_handshake.capabilities.value())
    client_handshake.capabilities.compress = False
    data = client_handshake.pack()
    self.s.send(Packet(size=len(data), sequence=1).pack())
    self.s.send(data)

    data = self.verifyPacket(2)[1]
    result = OkResult(data)
    self.assertEqual(result.affected_rows, 0)
    self.assertEqual(result.insert_id, 0)
    self.assertEqual(result.warning_count, 0)

  def _recvExact(self, size):
    '''
    Receive size bytes from self.s.

    A single recv() may return fewer bytes than requested, so keep reading
    until the full amount has arrived. Fewer than size bytes are returned
    only when the peer closes the connection first; callers assert on the
    length.
    '''
    data = self.s.recv(size)
    while data and len(data) < size:
      chunk = self.s.recv(size - len(data))
      if not chunk:
        break
      data += chunk
    return data

  def verifyPacket(self, sequence):
    '''
    Read one packet and check its framing.

    Asserts that a complete 4-byte header arrived, that the packet size is
    positive, that the sequence number equals `sequence` and that the whole
    payload was received. Returns the tuple (packet, data) of the parsed
    header and the raw payload.
    '''
    data = self._recvExact(4)
    self.assertEqual(len(data), 4)

    packet = Packet(data)
    self.assertTrue(packet.size > 0)
    self.assertEqual(packet.sequence, sequence)

    data = self._recvExact(packet.size)
    self.assertEqual(len(data), packet.size)

    return (packet, data)

  def verifyServerHandshake(self):
    '''
    Read the server handshake (sequence 0), check its fixed fields and
    return the parsed ServerHandshake.
    '''
    data = self.verifyPacket(0)[1]

    server_handshake = ServerHandshake(data)
    self.assertEqual(server_handshake.protocol_version, 10)
    self.assertEqual(server_handshake.null1, 0)
    self.assertEqual(server_handshake.status.value(), 2)
    self.assertEqual(server_handshake.unused, tuple([0] * 13))
    self.assertEqual(server_handshake.null2, 0)
    return server_handshake

class SocketEof(Exception):
  '''Raised by TestCommand.readData when recv() returns no data, i.e. the
  connection was closed before the requested bytes arrived.'''

class TestCommand(unittest.TestCase):
  '''
  Tests of the command phase.

  setUp() completes the handshake, so every test starts on a connection
  that is ready to accept commands. All reads go through readData(), which
  buffers received bytes in self._data.
  '''

  def setUp(self):
    '''
    Connect to options.host:options.port and perform the handshake, echoing
    the server's capabilities with compression switched off.
    '''
    address = (options.host, int(options.port))
    # The receive buffer must exist before the first readData() call.
    self._data = ''
    self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
      self.s.connect(address)

      # Read server handshake. readData() is used instead of a bare recv()
      # because a single recv() may return fewer bytes than requested.
      packet = Packet(self.readData(4))
      server_handshake = ServerHandshake(self.readData(packet.size))

      # Send client handshake
      client_handshake = ClientHandshake(capabilities=server_handshake.capabilities.value())
      client_handshake.capabilities.compress = False
      self.sendPacket(client_handshake.pack(), sequence=1)

      # Read server response
      packet = Packet(self.readData(4))
      OkResult(self.readData(packet.size))
    except Exception:
      # unittest does not run tearDown() when setUp() raises, so release
      # the socket here instead of leaking it.
      self.s.close()
      raise

  def tearDown(self):
    '''Close the connection opened by setUp().'''
    self.s.close()

  def readData(self, size):
    '''
    Return exactly size bytes from the connection.

    Data is received in chunks of up to 8192 bytes; anything beyond size is
    kept in self._data for the next call. Raises SocketEof if the connection
    is closed before size bytes are available.
    '''
    while len(self._data) < size:
      data = self.s.recv(8192)
      if len(data) == 0:
        raise SocketEof()
      self._data += data

    return_data = self._data[:size]
    self._data = self._data[size:]
    return return_data

  def sendPacket(self, data, sequence=0):
    '''Send a packet header for len(data) bytes followed by data itself.'''
    self.s.send(Packet(size=len(data), sequence=sequence).pack())
    self.s.send(data)

  def readPacket(self, sequence=None):
    '''
    Read one packet and return its payload.

    When sequence is given, the packet's sequence number is asserted to be
    equal to it before the payload is read.
    '''
    packet = Packet(self.readData(4))
    if sequence is not None:
      self.assertEqual(packet.sequence, sequence)
    return self.readData(packet.size)

  def testEmptyCommand(self):
    '''A default-constructed Packet must be answered with UNKNOWN_COMMAND.'''
    self.s.send(Packet().pack())
    result = create_result(self.readPacket())
    self.assertTrue(isinstance(result, ErrorResult))
    self.assertEqual(result.error_code, UNKNOWN_COMMAND)

  def testInvalidCommands(self):
    '''
    SLEEP, CONNECT, TIME and DELAYED_INSERT must be answered with
    UNKNOWN_COMMAND for a range of payload sizes.
    '''
    for command in [CommandID.SLEEP, CommandID.CONNECT, CommandID.TIME, CommandID.DELAYED_INSERT]:
      # Test a range of sizes for each
      for x in [0, 1, 250, 251, 252, 253, 254, 255, 256, 65535, 65536]:
        self.sendPacket(Command(command=command, payload='x' * x).pack())

        result = create_result(self.readPacket(1))
        self.assertTrue(isinstance(result, ErrorResult))
        self.assertEqual(result.error_code, UNKNOWN_COMMAND)

  def testQuitCommand(self):
    '''After QUIT the next read must hit end-of-stream (SocketEof).'''
    self.sendPacket(Command(command=CommandID.QUIT).pack())
    self.assertRaises(SocketEof, self.readData, 4)

  def testQuitCommandData(self):
    '''QUIT carrying a 65536-byte payload must also end the connection.'''
    self.sendPacket(Command(command=CommandID.QUIT, payload='x' * 65536).pack())
    self.assertRaises(SocketEof, self.readData, 4)

  def testInitDBCommand(self):
    '''
    INIT_DB with database names of various lengths made of 'x' characters
    must fail: NO_DATABASE_SELECTED for the empty name, otherwise
    DB_ACCESS_DENIED, WRONG_DB_NAME or BAD_DB.
    '''
    for x in [0, 2, 250, 251, 252, 253, 254, 255, 256, 65535, 65536, 1024*512]:
      self.sendPacket(Command(command=CommandID.INIT_DB, payload='x' * x).pack())

      result = create_result(self.readPacket(1))
      self.assertTrue(isinstance(result, ErrorResult))
      if x == 0:
        self.assertEqual(result.error_code, NO_DATABASE_SELECTED)
      else:
        self.assertTrue(result.error_code == DB_ACCESS_DENIED or
                        result.error_code == WRONG_DB_NAME or
                        result.error_code == BAD_DB)

  def testRangeQuery(self):
    '''
    Run SELECT '<x times "x">' for every x in common_range and verify the
    complete result set: one column definition, EOF, one row holding the
    literal, EOF.

    The test stops at the first PACKET_TOO_LARGE error, or when sending a
    query of at least 1024*1024 bytes fails.
    '''
    for x in common_range:
      data = QueryCommand(query="SELECT '" + 'x' * x + "'").pack()
      self.s.send(Packet(size=len(data), sequence=0).pack())
      try:
        self.s.send(data)
      except socket.error:
        self.assertTrue(len(data) >= 1024*1024)
        return

      result = create_result(self.readPacket(1))

      if isinstance(result, CountResult):
        column = Column(self.readPacket(2))
        self.assertEqual(column.catalog, 'def')
        self.assertEqual(column.db, '')
        self.assertEqual(column.table, '')
        self.assertEqual(column.orig_table, '')
        # Only require that the name consists of 'x' characters; its length
        # is not compared with x.
        self.assertEqual(column.name, 'x' * len(column.name))
        self.assertEqual(column.orig_name, '')
        self.assertEqual(column.unused1, 12)
        # Match both 3 and 4 byte characters depending on character set.
        self.assertTrue(column.size == x * 3 or column.size == x*4)
        self.assertTrue(column.type == ColumnType.VAR_STRING or
                        column.type == ColumnType.VARCHAR)
        self.assertEqual(column.flags.value(), 1)
        self.assertEqual(column.decimal, 31)
        self.assertEqual(column.unused2, tuple())
        self.assertEqual(column.default_value, '')

        EofResult(self.readPacket(3))

        row = parse_row(1, self.readPacket(4))
        self.assertEqual(row[0], 'x' * x)

        EofResult(self.readPacket(5))
      else:
        self.assertTrue(isinstance(result, ErrorResult))
        self.assertEqual(result.error_code, PACKET_TOO_LARGE)
        return

  def testRangeResult(self):
    '''
    Run SELECT REPEAT('x', x) for every x in common_range and verify the
    result set; a non-NULL value must be exactly x characters long.

    The test stops at the first PACKET_TOO_LARGE error.
    '''
    for x in common_range:
      self.sendPacket(QueryCommand(query="SELECT REPEAT('x', %s)" % x).pack())

      result = create_result(self.readPacket(1))

      if isinstance(result, CountResult):
        column = Column(self.readPacket(2))
        self.assertEqual(column.name, "REPEAT('x', %s)" % x)
        self.assertTrue(column.type == ColumnType.VAR_STRING or
                        column.type == ColumnType.VARCHAR or
                        column.type == ColumnType.MEDIUM_BLOB or
                        column.type == ColumnType.BLOB)

        EofResult(self.readPacket(3))

        row = parse_row(1, self.readPacket(4))
        if row[0] is not None:
          self.assertEqual(len(row[0]), x)

        EofResult(self.readPacket(5))
      else:
        self.assertTrue(isinstance(result, ErrorResult))
        self.assertEqual(result.error_code, PACKET_TOO_LARGE)
        return

  def testRangeColumn(self):
    '''
    Select x+1 literal 'x' columns, for x below 300 and just below 65536,
    and verify every column definition and the single row.

    Expected sequence numbers are taken modulo 256. The test stops at the
    first PACKET_TOO_LARGE error.
    '''
    # list() keeps the concatenation valid where range() is not a list.
    for x in list(range(0, 300)) + list(range(1024*64 - 20, 1024*64)):
      self.sendPacket(QueryCommand(query="SELECT 'x'" + ",'x'" * x).pack())

      result = create_result(self.readPacket(1))

      if isinstance(result, CountResult):
        column_count = x + 1
        for y in range(0, column_count):
          column = Column(self.readPacket((2+y) % 256))
          self.assertEqual(column.name, 'x')
          self.assertTrue(column.type == ColumnType.VAR_STRING or
                          column.type == ColumnType.VARCHAR or
                          column.type == ColumnType.MEDIUM_BLOB or
                          column.type == ColumnType.BLOB)

        EofResult(self.readPacket((3+x) % 256))

        # Use the column count directly rather than the loop variable left
        # over from the column loop above.
        row = parse_row(column_count, self.readPacket((4+x) % 256))
        self.assertEqual(len(row[0]), 1)

        EofResult(self.readPacket((5+x) % 256))
      else:
        self.assertTrue(isinstance(result, ErrorResult))
        self.assertEqual(result.error_code, PACKET_TOO_LARGE)
        return

  def testPingCommand(self):
    '''PING must be answered with an OK result.'''
    self.sendPacket(Command(command=CommandID.PING).pack())
    result = create_result(self.readPacket(1))
    self.assertTrue(isinstance(result, OkResult))

  # The SHUTDOWN command is not tested because we don't want the server going
  # away while testing.
  #
  # Ignoring the following commands for now because drizzle does not support
  # them. If support is added, be sure they support UNKNOWN_COMMAND for Drizzle.
  # (PING appears in this list but is exercised by testPingCommand.)
  #
  #   FIELD_LIST = 4
  #   CREATE_DB = 5
  #   DROP_DB = 6
  #   REFRESH = 7
  #   SHUTDOWN = 8
  #   STATISTICS = 9
  #   PROCESS_INFO = 10
  #   PROCESS_KILL = 12
  #   DEBUG = 13
  #   PING = 14
  #   CHANGE_USER = 17
  #   BINLOG_DUMP = 18
  #   TABLE_DUMP = 19
  #   CONNECT_OUT = 20
  #   REGISTER_SLAVE = 21
  #   STMT_PREPARE = 22
  #   STMT_EXECUTE = 23
  #   STMT_SEND_LONG_DATA = 24
  #   STMT_CLOSE = 25
  #   STMT_RESET = 26
  #   SET_OPTION = 27
  #   STMT_FETCH = 28
  #   DAEMON = 29

  def testBadCommands(self):
    '''Every command id from CommandID.END up to 255 must be answered with
    UNKNOWN_COMMAND.'''
    for x in range(CommandID.END, 256):
      self.sendPacket(Command(command=x).pack())

      result = create_result(self.readPacket(1))
      self.assertTrue(isinstance(result, ErrorResult))
      self.assertEqual(result.error_code, UNKNOWN_COMMAND)

class CustomTestRunner(unittest.TextTestRunner):
  '''
  Text test runner whose summary consists of the number of tests run and
  either "OK" or a "FAILED (...)" line with failure and error counts.
  '''

  def run(self, test):
    '''Run test, print the collected errors and the summary to self.stream,
    and return the result object.'''
    result = self._makeResult()
    test(result)
    result.printErrors()

    out = self.stream
    out.writeln(result.separator2)
    count = result.testsRun
    plural = "" if count == 1 else "s"
    out.writeln("Ran %d test%s" % (count, plural))
    out.writeln()

    if result.wasSuccessful():
      out.writeln("OK")
      return result

    # Collect only the non-zero counters, then emit them comma-separated.
    details = []
    if result.failures:
      details.append("failures=%d" % len(result.failures))
    if result.errors:
      details.append("errors=%d" % len(result.errors))
    out.write("FAILED (")
    out.write(", ".join(details))
    out.writeln(")")
    return result

if __name__ == '__main__':
  # Run every test in this module, or only what --test names.
  loader = unittest.TestLoader()
  this_module = sys.modules[__name__]
  if options.testcase is None:
    suite = loader.loadTestsFromModule(this_module)
  else:
    # Resolve the name by attribute lookup in this module instead of passing
    # command-line text to eval(). A test case class name works as before,
    # and "Class.method" selects a single test.
    suite = loader.loadTestsFromName(options.testcase, this_module)
  CustomTestRunner(stream=sys.stdout, verbosity=2).run(suite)
