import socket
import ssl
import time
import sys
import os
import threading
import time

lock = threading.Lock()

def tcp_handshake(ip, port):
    """Attempt a single TCP connection to ``ip:port`` and record the outcome.

    The connection attempt is bounded by the module-level ``TIMEOUT``
    (seconds). One semicolon-separated record is printed to stdout and
    appended to the module-level output file ``f``; the file write is
    serialized with the module-level ``lock`` because this function is run
    from several threads at once. Record layout:

        epoch_time;sIP;sport;dIP;dport;duration;handshake_success;errno;message

    Args:
        ip: Destination IPv4 address or host name.
        port: Destination TCP port.

    Returns:
        None. The result is only printed and written to ``f``.

    Raises:
        Only ``OSError`` subclasses raised by ``connect`` are handled here;
        any other exception propagates to the caller.
    """
    tcp_handshake_success = False
    errno = ''
    timestamp = time.time()

    print("Connecting to {}:{}".format(ip, port))
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # The timeout must be set before connect(), otherwise the TCP
        # handshake itself is not bounded by it.
        s.settimeout(TIMEOUT)
        try:
            s.connect((ip, port))
        except socket.timeout as E:
            # connect() did not complete within TIMEOUT seconds
            # (e.g. no SYN/ACK came back from the server).
            errno = E.errno
            message = "TIMEOUT"
        except ConnectionRefusedError as E:
            errno = E.errno
            message = "REFUSED"
        except ConnectionResetError as E:
            # Record errno here too, like every other error branch.
            errno = E.errno
            message = "RST"
        except socket.error as E:
            # Any other socket-level failure (unreachable network, name
            # resolution failure, ...); keep the system's own description.
            errno = E.errno
            message = str(E)
        else:
            tcp_handshake_success = True
            message = "OK"
        finally:
            # Read the local endpoint while the socket is still open.
            src_ip, src_port = s.getsockname()
    endtime = time.time()

    result = (timestamp, src_ip, src_port, ip, port, endtime - timestamp,
              tcp_handshake_success, errno, message)

    # Format once; the same record goes to stdout and to the output file.
    record = ';'.join(str(x) for x in result)
    print(record)
    with lock:
        f.write(record + "\n")


import argparse

# Script entry point: parse the command line, then probe every (ip, port)
# pair with one thread per probe, throttled by --max-threads.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='A small program that does TCP handshakes.',
        epilog='''
        Usage:\n
        python3 syn_ping.py -iL ip_list.txt -p 80 443 -o output.csv -t 5 -m 100

        Output format:
        epoch_time;sIP;sport;dIP;dport;duration;handshake_success;errno;message''')
    parser.add_argument('-iL', '--input-list', required=True, help='A file containing dest ips or hosts')
    parser.add_argument('-o', '--output', required=True, help='output file')
    parser.add_argument('-p','--ports', nargs='+', type=int, help='dest ports', required=True)
    parser.add_argument('-t', '--timeout', type=int, default=10, help='default: 10s')
    parser.add_argument('-m', '--max-threads', type=int, default=1)
    args = parser.parse_args()

    ### Arguments
    INPUT_FILE = args.input_list
    # Close the input file promptly and skip blank lines, which would
    # otherwise be probed as an empty host name.
    with open(INPUT_FILE) as ip_file:
        dst_ips = [line.strip() for line in ip_file if line.strip()]
    OUTPUT_FILE = args.output
    dst_ports = args.ports
    # Max number of concurrent threads; the more resources the machine has,
    # the higher this value can be set.
    MAX_THREAD = args.max_threads
    # Module-level globals read by tcp_handshake(): TIMEOUT and f.
    TIMEOUT = args.timeout

    threads = []
    with open(OUTPUT_FILE, 'a+') as f:
        try:
            for ip in dst_ips:
                for port in dst_ports:
                    worker = threading.Thread(target=tcp_handshake, args=(ip, port))
                    threads.append(worker)
                    worker.start()
                    if len(threads) > MAX_THREAD:
                        # Over the limit: wait for the oldest half to finish
                        # and forget them, so the list only ever holds
                        # threads that may still be running.
                        half = len(threads) // 2
                        for finished in threads[:half]:
                            finished.join()
                        del threads[:half]
        finally:
            # Wait for every outstanding worker before the output file is
            # closed, even if the loop above was interrupted; workers write
            # to f.
            for worker in threads:
                worker.join()
