#!/usr/bin/python
#
# Copyright (c) 2013-2014 Yubico AB
# See the file COPYING for licence statement.
#
"""
Import AEADs to database.
"""

#import lib
import os
import sys
import io
import hashlib
import re
import time
import argparse
import sqlalchemy

from os.path import abspath
sys.path.append('Lib')
from pyhsm.util import key_handle_to_int
from pyhsm.yubikey import modhex_decode
import pyhsm.aead_cmd


##########################
# Functions Declarations #
##########################



#
# extract keyhandle value from the path
#
def extract_keyhandle(path, filepath):
    """Return the first path component of *filepath* below *path*.

    path     -- base directory that was walked (with or without a trailing "/")
    filepath -- full path of a file found below *path*

    The component directly under *path* is the key handle directory name.
    """
    # str.lstrip() takes a *set of characters*, not a prefix, so using it here
    # would also eat leading key handle characters that happen to occur in
    # path. Remove the literal prefix instead.
    if filepath.startswith(path):
        remainder = filepath[len(path):]
    else:
        remainder = filepath

    # Drop the separator left over when path has no trailing "/".
    return remainder.lstrip("/").split("/")[0]

#
# insert_query: builds an SQL insert from the AEAD fields and executes it against the database
#
def insert_query(publicId, aead, keyhandle, aeadobj):
    """Insert one AEAD row into the database table.

    publicId  -- value stored in the public_id column
    aead      -- object providing key_handle, nonce and data attributes
    keyhandle -- key handle to check; converted with key_handle_to_int and
                 compared against aead.key_handle before anything is inserted
    aeadobj   -- table object whose insert() builds the statement

    Returns the result of connection.execute() on success. Returns None when
    the key handles disagree or the insert raises IntegrityError.
    """
    #turn the keyhandle into an integer
    keyhandle = key_handle_to_int(keyhandle)
    if keyhandle != aead.key_handle:
        print("WARNING: keyhandle does not match aead.key_handle")
        return None

    #creates the query object
    sql = aeadobj.insert().values(public_id=publicId, keyhandle=aead.key_handle, nonce=aead.nonce, aead=aead.data)

    try:
        #insert the query (uses the module-level connection)
        return connection.execute(sql)
    except sqlalchemy.exc.IntegrityError:
        # Best effort: the row is rejected by the database; the caller
        # reports the failed insert based on the None return value.
        return None


#################################
# END of function declarations  #
#################################

#######################
#                     #
# Initialization Area #
#                     #
#######################

# Command-line handling. Both positional arguments are required; argparse
# itself prints a usage error and exits with status 2 when they are missing.
parser = argparse.ArgumentParser(
    description='Import AEADs into the database',
    epilog='i.e. python import_aeads.py /root/aeads/ '
           'mysql://root:password@localhost:3306/database_name')

parser.add_argument('path', action="store", type=str)
parser.add_argument('dburl', action="store")

args = vars(parser.parse_args())

#set the path
path = args['path']
#mysql url
databaseUrl = args['dburl']

# Validate the parsed argument rather than a raw sys.argv slot.
if not os.path.isdir(path):
    print("\nInvalid path, check your spelling.\n")
    sys.exit(2)

try:
    #check database connection
    engine = sqlalchemy.create_engine(databaseUrl)

    #reflect the existing aead_table definition from the database
    metadata = sqlalchemy.MetaData()
    aeadobj = sqlalchemy.Table('aead_table', metadata, autoload=True, autoload_with=engine)
    connection = engine.connect()
except Exception as err:
    # Exception (not a bare except) so Ctrl-C and SystemExit still propagate.
    print("FATAL: Database connect failure")
    # Keep the underlying reason visible for troubleshooting.
    sys.stderr.write("%s\n" % (err,))
    sys.exit(1)

####################
# Computation area #
####################

# Only file names made up entirely of these characters are treated as AEADs;
# the file name doubles as the public id.
_PUBLIC_ID_RE = re.compile(r'^[cbdefghijklnrtuv]+$')

try:
    for root, subFolders, files in os.walk(path):
        # Look at every file in the directory, not just the first one, so
        # that no matching file is silently skipped.
        for filename in files:
            if not _PUBLIC_ID_RE.match(filename):
                continue

            #build file path
            filepath = os.path.join(root, filename)

            #extract the key handle from the path
            keyhandle = extract_keyhandle(path, filepath)
            kh_int = pyhsm.util.key_handle_to_int(keyhandle)

            #instantiate a new aead object and fill it from the file
            aead = pyhsm.aead_cmd.YHSM_GeneratedAEAD(None, kh_int, '')
            aead.load(filepath)

            #set the public_id
            public_id = str(filename)

            #check it is old format aead
            if not aead.nonce:
                #configure values for oldformat: nonce comes from the
                #public id, key handle from the directory name
                aead.nonce = pyhsm.yubikey.modhex_decode(public_id).decode('hex')
                aead.key_handle = kh_int

            if not insert_query(public_id, aead, keyhandle, aeadobj):
                print("WARNING: could not insert %s" % public_id)
finally:
    #close sqlalchemy, even when an error aborts the walk
    connection.close()

#exit without error
sys.exit(0)
