265 lines
9.3 KiB
Python
265 lines
9.3 KiB
Python
import ipaddress
import json
import logging, traceback
import socket
import struct
import threading
import time, datetime
import uuid
from enum import Enum

from ..utils.observable import Observable
from ..utils.wrapper_util import threaded
from .security import Security
|
|
|
|
# TODO (not feasible): a peer's IP address cannot be hidden in a direct P2P connection.
|
|
|
|
# Frame header layout (network byte order): payload size, encrypted flag,
# nonce length, MAC length, is-binary flag.
HEADER_STRUCTURE: str = '!I?IIb'

# Packed header length in bytes: 4 + 1 + 4 + 4 + 1 = 14.
HEADER_SIZE: int = struct.calcsize(HEADER_STRUCTURE)

# Number of connect() attempts made by Connection.connect before giving up.
MAX_CONNECTION_TRIES: int = 3
|
|
|
|
class EVENTS(Enum):
    """Events a Connection publishes to its observers."""

    ON_CONNECTION = 0        # handshake validated, connection ready
    ON_CONNECTION_ERROR = 1  # setup/connect/handshake failed; fired with ``error=...``
    ON_DISCONNECTION = 2     # connection closed
    ON_MESSAGE = 3           # a frame was received; fired with ``message=...``


# Every event in declaration order; Connection passes this list to Observable.__init__.
events = list(EVENTS)
|
|
|
|
class STATUS(Enum):
    """Lifecycle states of a Connection."""

    DISCONNECTED = 0  # initial state, or closed (locally or by the peer)
    CONNECTED = 1     # handshake validated; frames are being received
    CONNECTING = 2    # connect() attempts in progress
    HANDSHAKING = 3   # waiting for our signature / the peer's handshake payload
    ERROR = -1        # socket setup failed or every connect() attempt failed
|
|
|
|
class Connection(Observable):
    """One framed, optionally encrypted TCP connection to a peer.

    Flow: ``connect()`` (or a socket passed to the constructor) ->
    ``new_connection()`` starts the signed handshake -> on success the state
    becomes ``STATUS.CONNECTED`` and received frames are published as
    ``EVENTS.ON_MESSAGE``.

    Wire format: each frame is a ``HEADER_STRUCTURE`` header (payload size,
    encrypted flag, nonce length, MAC length, is-binary flag) followed by
    ``nonce + mac + payload``; the size field covers all three parts.
    """

    # Seconds to wait for the handshake to complete before aborting it
    # (see wait_signature). Override per subclass/instance if needed.
    HANDSHAKE_TIMEOUT_SECONDS = 120

    # Upper bound for one recv() call in recv_all, so a single call never
    # requests a buffer as large as the (peer-controlled) frame size.
    RECV_CHUNK_SIZE = 1 << 18

    def __init__(self, user, pmc, conn=None):
        """Create a connection wrapper.

        Args:
            user: forwarded to ``Security``.
            pmc: forwarded to ``Security`` (meaning not visible here).
            conn: optional socket that is already connected; when given,
                its peer and local addresses are recorded immediately.
        """
        super().__init__(events)
        self.security = Security(user, pmc)
        self.status = STATUS.DISCONNECTED
        self.address = None
        self.hostname = None
        self.bind_address = None
        # send_message/send_binary are @threaded, so several sends can run at
        # once; this lock keeps each frame's encrypt+sendall atomic so frames
        # never interleave on the socket.
        self._send_lock = threading.Lock()
        self.conn = conn
        if conn:
            self.set_addresses()
        self.id = str(uuid.uuid4())
        self.handshake_payload = None

    def set_addresses(self, address=None):
        """Record the peer address (and, without *address*, the local one).

        With *address* (a ``(host, port)`` pair) it is stored as the peer
        address, and *host* is also stored as ``hostname`` when it is not an
        IP literal. Without it, both addresses are read from ``self.conn``.
        """
        if address:
            self.address = address
            host, _port = address
            try:
                ipaddress.ip_address(host)
            except ValueError:
                self.hostname = host
        else:
            self.address = self.conn.getpeername()
            self.bind_address = self.conn.getsockname()

    @threaded
    def connect(self, address, bind_address=('0.0.0.0', 0)):
        """Open a TCP connection to *address* and start the handshake.

        The socket is bound to *bind_address* (with SO_REUSEADDR) and
        ``connect()`` is tried up to ``MAX_CONNECTION_TRIES`` times. On
        success ``new_connection()`` starts the handshake.

        On failure the socket is closed, the status becomes ``STATUS.ERROR``,
        ``ON_CONNECTION_ERROR`` is fired and the exception is re-raised
        (inside the @threaded wrapper).
        """
        self.set_addresses(address)
        self.bind_address = bind_address

        s = None
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            #s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            s.bind(bind_address)
            self.conn = s
        except Exception:
            if s is not None:
                s.close()  # don't leak the half-configured socket
            self.status = STATUS.ERROR
            self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Error to setup connection")
            raise

        self.status = STATUS.CONNECTING
        logging.info(f'Socket trying to connect: {self.bind_address} -> {address}')

        # NOTE(review): an earlier version set a 10 s timeout here but cleared it
        # again before every connect(), so attempts block with the OS default
        # timeout. Kept as-is; revisit if a bounded connect timeout was intended.
        s.settimeout(None)
        self.bind_address = s.getsockname()

        # NOTE(review): POSIX leaves a socket's state unspecified after a failed
        # connect(); retrying on the same socket is not portable everywhere.
        # Consider creating a fresh socket per attempt.
        for attempt in range(1, MAX_CONNECTION_TRIES + 1):
            try:
                s.connect(address)
                self.address = s.getpeername()
                break
            except Exception as e:
                logging.exception("ERROR")
                if attempt < MAX_CONNECTION_TRIES:
                    continue
                s.close()  # all attempts failed: release the socket
                self.status = STATUS.ERROR
                self.fire_event(EVENTS.ON_CONNECTION_ERROR, error=f"No connection could be made in {MAX_CONNECTION_TRIES} retries: {e}")
                raise

        self.new_connection()

    def new_connection(self):
        """Start the handshake on an established socket."""
        self.handshake_create_payload()

    def handshake_create_payload(self):
        """Build our handshake payload and request an ECDSA signature for it.

        The signature arrives asynchronously in ``payload_signature_callback``;
        ``wait_signature`` aborts the handshake if it has not finished in time.
        """
        self.status = STATUS.HANDSHAKING
        logging.info(f'Socket handshaking: {self.bind_address} -> {self.address}')
        ecdsa_key = self.security.user
        proof_of_work = self.security.proof_of_work
        ecdh_public = self.security.ecdh.public_key_to_str()
        # NOTE(review): local, timezone-naive timestamp. If the peer checks
        # freshness against its own local clock, peers in different timezones
        # may disagree; confirm in Security.handshake_validation.
        date = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.handshake_payload = {'ecdsa': ecdsa_key, 'pof': proof_of_work, 'ecdh': ecdh_public, 'date': date}
        # Sign the JSON text as it is now: 'signature' is appended to the dict
        # later, and the receiver pops it again before re-serializing and
        # verifying (see payload_signature_callback).
        self.security.sign_message_ecdsa(json.dumps(self.handshake_payload), self.payload_signature_callback, "Handshake Connection")
        self.wait_signature()

    @threaded
    def wait_signature(self):
        """Abort the handshake if it is still pending after HANDSHAKE_TIMEOUT_SECONDS."""
        # TODO wire HANDSHAKE_TIMEOUT_SECONDS to the security time-validation config
        time.sleep(self.HANDSHAKE_TIMEOUT_SECONDS)
        if self.status == STATUS.HANDSHAKING:
            logging.info(f"Closing connection, payload signature time exceded")
            self.status = STATUS.DISCONNECTED
            self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Payload signature time exceded")
            self.close_connection()

    def payload_signature_callback(self, request_id, signature):
        """Continue the handshake once our payload has been signed.

        Sends our signed payload (unencrypted), then reads frames until the
        peer's handshake payload arrives and validates it. On success the
        state becomes CONNECTED, ON_CONNECTION fires and the receive loop
        starts. A falsy *signature*, or any error while reading or validating
        the peer's payload, fires ON_CONNECTION_ERROR and closes the
        connection.

        Args:
            request_id: id of the signing request (unused).
            signature: signature over our payload; falsy means signing failed.
        """
        if self.status != STATUS.HANDSHAKING:
            logging.info(f"Not using signature, connection is closed")
            return None

        if not signature:
            self.status = STATUS.DISCONNECTED
            self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Payload signature is None")
            self.close_connection()
            return None

        self.handshake_payload['signature'] = signature
        self.send_message(json.dumps(self.handshake_payload), False)

        while self.status == STATUS.HANDSHAKING:
            try:
                frame = self.message_reader()
                # message_reader returns None only after closing the connection;
                # the loop condition then ends the handshake.
                message = frame["data"] if frame else None
                if not message:
                    continue
                payload = json.loads(message)
                payload_signed = payload.copy()
                payload_signed.pop('signature')
                # NOTE(review): the return value is ignored, so this relies on
                # handshake_validation raising for an invalid payload; confirm.
                self.security.handshake_validation(payload['ecdsa'], json.dumps(payload_signed), payload['signature'], payload['ecdh'], payload['pof'], payload['date'])
            except Exception as e:
                if self.status != STATUS.HANDSHAKING:
                    # Already torn down elsewhere (e.g. the handshake timeout
                    # closed the socket under this read); nothing left to do.
                    logging.info(f"Handshake aborted: {e}")
                    return None
                logging.exception("ERROR")
                self.status = STATUS.DISCONNECTED
                self.fire_event(EVENTS.ON_CONNECTION_ERROR, error=f"Handshake failed: {e}")
                self.close_connection()
                return None

            self.status = STATUS.CONNECTED
            logging.info(f'Ready: {self.bind_address} -> {self.address}')
            self.fire_event(EVENTS.ON_CONNECTION)
            self.wait_message()
        return None

    @threaded
    def wait_message(self):
        """Receive frames while CONNECTED and publish each one as ON_MESSAGE.

        Read timeouts are logged and ignored. Any other read error ends the
        loop via ``_drop_connection``. Exceptions raised by ON_MESSAGE
        listeners are logged and do not stop the loop.
        """
        while self.status == STATUS.CONNECTED:
            try:
                message = self.message_reader()
            except socket.timeout:
                logging.exception("ERROR")
                continue
            except (ConnectionAbortedError, EOFError, ConnectionResetError, OSError):
                # TODO Maybe try to reconnect
                self._drop_connection()
                break
            except Exception:
                logging.exception("ERROR")
                self._drop_connection()
                break

            if message:
                try:
                    self.fire_event(EVENTS.ON_MESSAGE, message=message)
                except Exception:
                    logging.error(traceback.format_exc())

    def _drop_connection(self):
        """After a receive failure, close the socket and fire ON_DISCONNECTION once."""
        if self.status != STATUS.CONNECTED:
            # close_connection() already ran (it sets the status before closing
            # the socket) and fired ON_DISCONNECTION; don't report it twice.
            return
        self.status = STATUS.DISCONNECTED
        self._close_socket()
        self.fire_event(EVENTS.ON_DISCONNECTION)

    def _close_socket(self):
        """Shut down and close ``self.conn``, tolerating an already-dead socket."""
        if self.conn is None:
            return
        try:
            # Python's socket.close() docs recommend shutdown() first to close
            # the connection promptly (close() alone may leave a blocked recv()
            # in another thread waiting).
            self.conn.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not connected or already closed: nothing to shut down
        self.conn.close()

    def close_connection(self):
        """Close the socket, set DISCONNECTED and fire ON_DISCONNECTION."""
        logging.info(f'Socket closed: {self.address}')
        # Status first, so the receive thread woken by the shutdown can tell
        # this was a local close and doesn't fire ON_DISCONNECTION again.
        self.status = STATUS.DISCONNECTED
        self._close_socket()
        self.fire_event(EVENTS.ON_DISCONNECTION)

    @threaded
    def send_binary(self, data: bytes, meta: dict = None, encrypted=True):
        """Send *data* with a JSON metadata dict as one binary frame.

        Plaintext layout: 4-byte big-endian metadata length, UTF-8 JSON
        metadata, then the raw *data*. With *encrypted* the whole plaintext
        goes through ``Security.encrypt_message``.
        """
        meta_encoded = json.dumps(meta or {}).encode('utf-8')
        plaintext = struct.pack('!I', len(meta_encoded)) + meta_encoded + data
        self._send_frame(plaintext, encrypted, is_binary=True)

    @threaded
    def send_message(self, message: str, encrypted=True):
        """Send *message* (UTF-8 encoded) as one text frame, optionally encrypted."""
        self._send_frame(message.encode('utf-8'), encrypted, is_binary=False)

    def _send_frame(self, payload: bytes, encrypted, is_binary: bool):
        """Optionally encrypt *payload*, prepend the frame header and write it out."""
        # Encryption stays inside the lock so frames hit the wire in the same
        # order they were encrypted (matters if the cipher keeps state).
        with self._send_lock:
            nonce = b''
            mac = b''
            if encrypted:
                nonce, payload, mac = self.security.encrypt_message(payload)
            header = struct.pack(
                HEADER_STRUCTURE,
                len(nonce) + len(mac) + len(payload),
                encrypted,
                len(nonce),
                len(mac),
                1 if is_binary else 0,
            )
            self.conn.sendall(header + nonce + mac + payload)

    def message_reader(self):
        """Read and decode one complete frame.

        Returns:
            ``{"data": str}`` for a text frame, ``{"meta": ..., "data": bytes}``
            for a binary frame, or ``None`` if the header read came back empty
            (the connection is closed first).

        Raises:
            ConnectionError: the peer closed the socket mid-frame.
            ValueError: the header's nonce/MAC lengths exceed the frame size.
            Anything raised by decryption, JSON parsing or UTF-8 decoding.
        """
        message_header = self.recv_all(HEADER_SIZE)
        if not message_header:
            self.close_connection()
            return None

        total_len, encrypted, nonce_len, mac_len, is_binary = struct.unpack(HEADER_STRUCTURE, message_header)
        data = self.recv_all(total_len)

        if encrypted:
            # Header fields come from the peer: reject lengths that don't fit.
            if nonce_len + mac_len > total_len:
                raise ValueError("Invalid frame header: nonce/MAC longer than frame")
            nonce = data[:nonce_len]
            mac = data[nonce_len:nonce_len + mac_len]
            payload = data[nonce_len + mac_len:]
            decrypted = self.security.decrypt_message(nonce, payload, mac)
        else:
            decrypted = data

        if is_binary:
            meta_length = struct.unpack('!I', decrypted[:4])[0]
            meta_raw = decrypted[4:4 + meta_length]
            meta = json.loads(meta_raw)
            file_data = decrypted[4 + meta_length:]
            return {"meta": meta, "data": file_data}
        return {"data": decrypted.decode('utf-8')}

    def recv_all(self, n: int) -> bytes:
        """Read exactly *n* bytes from ``self.conn``.

        Raises:
            ConnectionError: the peer closed the connection before *n* bytes arrived.
        """
        # bytearray appends are amortized O(1); the old immutable `bytes +=`
        # copied the whole buffer on every chunk (quadratic for big frames).
        buffer = bytearray()
        while len(buffer) < n:
            chunk = self.conn.recv(min(n - len(buffer), self.RECV_CHUNK_SIZE))
            if not chunk:
                raise ConnectionError("Connection closed before receive all bytes")
            buffer += chunk
        return bytes(buffer)
|