# Copyright 2010  Lars Wirzenius
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


'''Test paramiko by doing an sftp copy from localhost.'''


import errno
import os
import pwd
import socket
import subprocess
import sys
import tempfile
import time

import paramiko


class SSHChannelAdapter(object):

    '''Take an ssh subprocess and pretend it is a paramiko Channel.

    Only the methods used by this script's SFTP client are provided:
    send, recv, get_name and close.  Data travels over the
    subprocess's stdin (outgoing) and stdout (incoming) pipes.

    '''

    # errno values that indicate the other end has gone away.
    _CLOSED_ERRNOS = (errno.EPIPE, errno.ECONNRESET, errno.ECONNABORTED,
                      errno.EBADF)

    def __init__(self, proc):
        '''Wrap proc, an object with stdin, stdout and wait().'''
        self.proc = proc

    def send(self, data):
        '''Write data to the subprocess's stdin.

        Return the number of bytes actually written, which may be
        less than len(data).

        '''
        return os.write(self.proc.stdin.fileno(), data)

    def recv(self, count):
        '''Read up to count bytes from the subprocess's stdout.

        Return an empty string if the connection has been closed.
        Any other OS-level error is re-raised.

        '''
        try:
            return os.read(self.proc.stdout.fileno(), count)
        except (OSError, socket.error) as e:
            # os.read reports failures as OSError; socket.error is kept
            # for callers that relied on it being handled here.
            if e.errno in self._CLOSED_ERRNOS:
                # Connection has closed.  Paramiko expects an empty string in
                # this case, not an exception.
                return b''
            raise

    def get_name(self):
        '''Return a fixed, human-readable name for this channel.'''
        return 'obnam SSHChannelAdapter'

    def close(self):
        '''Close both pipes and wait for the subprocess to exit.

        This is best-effort: an I/O error from any one step is ignored
        so that the remaining steps still run.

        '''
        proc = self.proc
        for func in [proc.stdin.close, proc.stdout.close, proc.wait]:
            try:
                func()
            except (IOError, OSError):
                pass


# Connection parameters: the invoking user, the local ssh server, and
# the remote file that will be read.
host = 'localhost'
port = 22
path = '/tmp/zeroes'
username = pwd.getpwuid(os.getuid()).pw_name
subsystem = 'sftp'

# Command line for running the sftp subsystem through the ssh client,
# with all forwarding switched off.
args = [
    'ssh',
    '-oForwardX11=no',
    '-oForwardAgent=no',
    '-oClearAllForwardings=yes',
    '-oProtocol=2',
]
# Optional flags are only passed when their value is set.
for flag, value in (('-p', port), ('-l', username)):
    if value is not None:
        args += [flag, str(value)]
args += ['-s', host, subsystem]

# Set up the SFTP client.  There are two ways to reach the server: an
# ssh subprocess wrapped in SSHChannelAdapter, or a native paramiko
# Transport.  Exactly one of proc/transport ends up non-None.
proc = transport = None
try:
    # NOTE: this unconditional raise makes the Popen call below
    # unreachable, so the paramiko Transport path in the except branch
    # is always the one exercised.  Remove it to use the ssh subprocess.
    raise OSError(None, None, None)
    proc = subprocess.Popen(args,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            close_fds=True)
except OSError:
    transport = paramiko.Transport((host, port))
    transport.connect()

    # Check the server's host key against known_hosts before
    # authenticating, so no credentials are offered to a host whose
    # key is known to be wrong.
    try:
        keys = paramiko.util.load_host_keys(
                os.path.expanduser('~/.ssh/known_hosts'))
    except IOError:
        print('*** Unable to open host keys file')
        transport.close()
        sys.exit(1)

    hostname = host
    key = transport.get_remote_server_key()

    k = keys.lookup(hostname)
    if k is None:
        print('unknown host %s (using it anyway)' % hostname)
    else:
        key_type = key.get_name()
        # Use "in" rather than has_key(): the latter is not available
        # on every mapping type paramiko may return here.
        if key_type not in k:
            print('no key of type %s for %s' % (key_type, hostname))
            print('accepting host key anyway')
        elif k[key_type] != key:
            print('wrong host key for %s' % hostname)
            transport.close()
            sys.exit(1)
        else:
            print('good host key for %s' % hostname)

    # Authenticate with the first ssh-agent key the server accepts.
    agent = paramiko.Agent()
    agent_keys = agent.get_keys()
    for agent_key in agent_keys:
        try:
            transport.auth_publickey(username, agent_key)
            break
        except paramiko.SSHException:
            pass
    else:
        # The loop finished without a successful authentication.
        raise Exception('no auth')

    sftp = paramiko.SFTPClient.from_transport(transport)
else:
    sftp = paramiko.SFTPClient(SSHChannelAdapter(proc))

# Read the remote file in 32 KiB chunks, timing only the read loop,
# then report the elapsed time and throughput.
n = 0
try:
    f = sftp.open(path)
    try:
        start = time.time()
        while True:
            data = f.read(32*1024)
            if not data:
                break
            n += len(data)
        end = time.time()
    finally:
        f.close()
finally:
    # Release the connection even if opening or reading failed.
    sftp.close()
    if transport:
        transport.close()
    if proc is not None:
        # Popen objects have no close() method; reap the child instead.
        proc.wait()

duration = end - start
n = float(n)
# Guard against a zero-length interval (empty file or coarse clock),
# which would otherwise raise ZeroDivisionError.
if duration > 0:
    mbits_per_second = 8 * n / duration / 1024 / 1024
else:
    mbits_per_second = 0.0
# Output columns: seconds elapsed, MiB read, Mbit/s.
print('%s %s %s' % (duration, n/1024/1024, mbits_per_second))
