Files
nosys_libs/fspn/protocol/connection.py
2026-01-25 13:55:46 +10:00

265 lines
9.3 KiB
Python

from ..utils.observable import Observable
from ..utils.wrapper_util import threaded
from .security import Security
from enum import Enum
import socket
import ipaddress
import struct
import time, datetime
import logging, traceback
import json
import uuid
# NOTE: hiding the peer's IP address is not possible in a direct p2p connection.
# Frame header layout, network byte order (no padding):
#   I  size        - number of bytes following the header (nonce + mac + payload)
#   ?  encrypted   - whether the payload is encrypted
#   I  nonce_len   - length of the nonce prefix
#   I  mac_len     - length of the mac that follows the nonce
#   b  is_binary   - 1 for binary frames (meta + raw data), 0 for text frames
HEADER_STRUCTURE = "!I?IIb"

# 4 + 1 + 4 + 4 + 1 = 14 bytes
HEADER_SIZE = struct.calcsize(HEADER_STRUCTURE)

# Number of connect attempts before the connection is reported as failed.
MAX_CONNECTION_TRIES = 3
class EVENTS(Enum):
    """Notifications a connection emits to its observers."""

    ON_CONNECTION = 0        # handshake completed, link is ready
    ON_CONNECTION_ERROR = 1  # connecting or handshaking failed
    ON_DISCONNECTION = 2     # link was closed or lost
    ON_MESSAGE = 3           # a frame was received


# Every event in declaration order; this is the list registered with Observable.
events = list(EVENTS)
class STATUS(Enum):
    """Lifecycle state of a connection."""

    DISCONNECTED = 0  # no usable link (initial state, or after close/loss)
    CONNECTED = 1     # handshake accepted, frames may flow
    CONNECTING = 2    # TCP connect in progress
    HANDSHAKING = 3   # exchanging and validating handshake payloads
    ERROR = -1        # socket setup or connect failed
class Connection(Observable):
    """A framed, optionally encrypted TCP link to a single peer.

    State is tracked in ``status`` (see ``STATUS``) and observers are told
    about changes through the events in ``EVENTS``.

    Every frame on the wire is a ``HEADER_STRUCTURE`` header followed by
    ``nonce + mac + payload``; for unencrypted frames nonce and mac are empty.

    NOTE(review): several methods use the project's ``@threaded`` decorator,
    which presumably runs the call on a background thread - confirm. If so,
    concurrent ``send_message`` / ``send_binary`` calls share one socket
    without a lock, and large frames could interleave on the wire.
    """

    # Seconds a single TCP connect attempt may take.
    _CONNECT_TIMEOUT = 10
    # Seconds the local handshake signature may take before the link is dropped.
    # TODO Config this time as security time validation
    _SIGNATURE_TIMEOUT = 120
    # Upper bound for one recv() call, so a huge announced frame size does not
    # turn into a huge single allocation.
    _RECV_CHUNK_SIZE = 64 * 1024

    def __init__(self, user, pmc, conn=None):
        """Create a connection object.

        Args:
            user: forwarded to ``Security``.
            pmc: forwarded to ``Security``.
            conn: optional already-connected socket; when given, the peer and
                local addresses are read from it immediately.
        """
        super().__init__(events)
        self.security = Security(user, pmc)
        self.status = STATUS.DISCONNECTED
        self.address = None
        self.hostname = None
        self.bind_address = None
        self.conn = conn
        if conn:
            self.set_addresses()
        self.id = str(uuid.uuid4())
        self.handshake_payload = None

    def set_addresses(self, address=None):
        """Record the peer address, either explicitly or from the socket.

        With *address* (a ``(host, port)`` pair) the value is stored as is and,
        when ``host`` is not an IP literal, also remembered as ``hostname``.
        Without it, the peer and local addresses are read from ``self.conn``.
        """
        if not address:
            self.address = self.conn.getpeername()
            self.bind_address = self.conn.getsockname()
            return
        self.address = address
        host = address[0]
        try:
            ipaddress.ip_address(host)
        except ValueError:
            # Not an IP literal, so it must be a name.
            self.hostname = host

    def _open_socket(self, bind_address):
        """Return a new TCP socket bound to *bind_address* with the connect timeout set."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(bind_address)
            sock.settimeout(self._CONNECT_TIMEOUT)
        except Exception:
            # Do not leak the descriptor when option/bind fails.
            sock.close()
            raise
        return sock

    @threaded
    def connect(self, address, bind_address=('0.0.0.0', 0)):
        """Open a TCP connection to *address* and start the handshake.

        Up to ``MAX_CONNECTION_TRIES`` attempts are made. Each attempt uses a
        fresh socket, because the state of a socket after a failed
        ``connect()`` is unspecified, and is limited by ``_CONNECT_TIMEOUT``.
        After a successful connect the socket is switched to blocking mode
        for the reader loop.

        Fires ``ON_CONNECTION_ERROR`` and re-raises when the socket cannot be
        set up or when the last attempt fails.
        """
        self.set_addresses(address)
        self.bind_address = bind_address
        self.status = STATUS.CONNECTING
        logging.info(f'Socket trying to connect: {self.bind_address} -> {address}')

        for attempt in range(1, MAX_CONNECTION_TRIES + 1):
            try:
                sock = self._open_socket(bind_address)
            except Exception:
                self.status = STATUS.ERROR
                self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Error to setup connection")
                raise

            try:
                self.conn = sock
                self.bind_address = sock.getsockname()
                sock.connect(address)
                self.address = sock.getpeername()
            except Exception as e:
                logging.exception("ERROR")
                sock.close()
                if attempt < MAX_CONNECTION_TRIES:
                    continue
                self.status = STATUS.ERROR
                self.fire_event(EVENTS.ON_CONNECTION_ERROR, error=f"No connection could be made in {MAX_CONNECTION_TRIES} retries: {e}")
                raise

            # The timeout only guards connect(); reads must block indefinitely.
            sock.settimeout(None)
            break

        self.new_connection()

    def new_connection(self):
        """Start the handshake on the current socket."""
        self.handshake_create_payload()

    def handshake_create_payload(self):
        """Build the local handshake payload and request its signature.

        The payload carries the local ECDSA identity, proof of work, ECDH
        public key and a timestamp. The signature arrives through
        ``payload_signature_callback``; ``wait_signature`` bounds the wait.
        """
        self.status = STATUS.HANDSHAKING
        logging.info(f'Socket handshaking: {self.bind_address} -> {self.address}')
        # NOTE(review): this is naive local time; if the peer checks freshness
        # across time zones this may need to be UTC - confirm against
        # Security.handshake_validation before changing the wire value.
        date = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.handshake_payload = {
            'ecdsa': self.security.user,
            'pof': self.security.proof_of_work,
            'ecdh': self.security.ecdh.public_key_to_str(),
            'date': date,
        }
        # Serialised before 'signature' is added, so the signed text is the
        # payload without its signature (the peer verifies it the same way).
        self.security.sign_message_ecdsa(json.dumps(self.handshake_payload), self.payload_signature_callback, "Handshake Connection")
        self.wait_signature()

    @threaded
    def wait_signature(self):
        """Drop the connection if the handshake has not finished in time."""
        time.sleep(self._SIGNATURE_TIMEOUT)
        if self.status == STATUS.HANDSHAKING:
            logging.info("Closing connection, payload signature time exceded")
            self._abort_handshake("Payload signature time exceded")

    def _abort_handshake(self, reason):
        """Report a failed handshake to observers and close the socket."""
        self.status = STATUS.DISCONNECTED
        self.fire_event(EVENTS.ON_CONNECTION_ERROR, error=reason)
        self.close_connection()

    def _validate_peer_handshake(self, message):
        """Parse the peer's handshake text and hand it to ``Security`` for validation.

        Raises whatever ``json`` or ``Security.handshake_validation`` raise,
        and ``KeyError`` when a required field is missing.
        """
        payload = json.loads(message)
        payload_signed = payload.copy()
        payload_signed.pop('signature')
        # NOTE(review): the return value is ignored, as before. This is only
        # safe if handshake_validation raises on rejection - confirm.
        self.security.handshake_validation(payload['ecdsa'], json.dumps(payload_signed), payload['signature'], payload['ecdh'], payload['pof'], payload['date'])

    def payload_signature_callback(self, request_id, signature):
        """Continue the handshake once the local payload has been signed.

        Sends the signed payload unencrypted, then reads the peer's payload
        and validates it. On success the connection becomes ``CONNECTED``,
        ``ON_CONNECTION`` fires and the reader loop starts. A missing
        signature, a read error or a validation error aborts the handshake
        with ``ON_CONNECTION_ERROR`` instead of leaving the connection stuck
        in ``HANDSHAKING``.

        Args:
            request_id: identifier supplied by the signer; unused here.
            signature: signature of the handshake payload, or a falsy value
                when signing failed.
        """
        if self.status != STATUS.HANDSHAKING:
            logging.info("Not using signature, connection is closed")
            return None
        if not signature:
            self._abort_handshake("Payload signature is None")
            return None

        self.handshake_payload['signature'] = signature
        self.send_message(json.dumps(self.handshake_payload), False)

        while self.status == STATUS.HANDSHAKING:
            try:
                frame = self.message_reader()
                message = frame["data"] if frame else None
                if not message:
                    continue
                self._validate_peer_handshake(message)
            except Exception as e:
                logging.exception("ERROR")
                # If someone else already closed the link (e.g. the signature
                # timer), do not report the failure a second time.
                if self.status == STATUS.HANDSHAKING:
                    self._abort_handshake(f"Handshake failed: {e}")
                return None

            self.status = STATUS.CONNECTED
            logging.info(f'Ready: {self.bind_address} -> {self.address}')
            self.fire_event(EVENTS.ON_CONNECTION)
            self.wait_message()
        return None

    @threaded
    def wait_message(self):
        """Read frames while connected and publish each as ``ON_MESSAGE``.

        Any read failure other than a timeout closes the socket and fires
        ``ON_DISCONNECTION``. Exceptions raised by listeners are logged and do
        not stop the loop.
        """
        while self.status == STATUS.CONNECTED:
            try:
                message = self.message_reader()
            except socket.timeout:
                logging.exception("ERROR")
                continue
            except (EOFError, OSError):
                # TODO Maybe try to reconnect
                self.close_connection()
                break
            except Exception:
                logging.exception("ERROR")
                self.close_connection()
                break

            try:
                if message:
                    self.fire_event(EVENTS.ON_MESSAGE, message=message)
            except Exception:
                logging.error(traceback.format_exc())

    def _close_socket(self):
        """Close the socket if there is one; never raises."""
        if self.conn is None:
            return
        try:
            self.conn.close()
        except OSError:
            logging.exception("ERROR")

    def close_connection(self):
        """Close the socket, mark the connection disconnected and notify observers."""
        logging.info(f'Socket closed: {self.address}')
        self._close_socket()
        self.status = STATUS.DISCONNECTED
        self.fire_event(EVENTS.ON_DISCONNECTION)

    def _send_frame(self, payload, encrypted, is_binary):
        """Encrypt *payload* if requested, prepend the header and send the frame."""
        nonce = b''
        mac = b''
        if encrypted:
            nonce, payload, mac = self.security.encrypt_message(payload)
        header = struct.pack(
            HEADER_STRUCTURE,
            len(nonce) + len(mac) + len(payload),
            bool(encrypted),
            len(nonce),
            len(mac),
            1 if is_binary else 0,
        )
        self.conn.sendall(header + nonce + mac + payload)

    @threaded
    def send_binary(self, data: bytes, meta: dict = None, encrypted=True):
        """Send *data* as a binary frame with JSON metadata.

        The frame payload is ``<4-byte meta length><meta JSON><data>``; the
        whole payload is encrypted when *encrypted* is true.
        """
        meta_encoded = json.dumps(meta or {}).encode('utf-8')
        full_payload = struct.pack('!I', len(meta_encoded)) + meta_encoded + data
        self._send_frame(full_payload, encrypted, True)

    @threaded
    def send_message(self, message: str, encrypted=True):
        """Send *message* as a UTF-8 text frame, encrypted unless told otherwise."""
        self._send_frame(message.encode('utf-8'), encrypted, False)

    def message_reader(self):
        """Read and decode one frame from the socket.

        Returns:
            ``{"data": str}`` for a text frame, ``{"meta": ..., "data": bytes}``
            for a binary frame, or ``None`` if no header could be read (in
            which case the connection has been closed).

        Raises:
            ConnectionError: the peer closed the socket mid-frame.
            ValueError: the header announces a nonce/mac larger than the frame.
        """
        message_header = self.recv_all(HEADER_SIZE)
        if not message_header:
            # Defensive: recv_all raises on EOF, so this is not expected.
            self.close_connection()
            return None

        total_len, encrypted, nonce_len, mac_len, is_binary = struct.unpack(HEADER_STRUCTURE, message_header)
        if encrypted and nonce_len + mac_len > total_len:
            # The header comes from the peer; never slice past the frame.
            raise ValueError("Malformed frame header: nonce and mac exceed frame size")

        data = self.recv_all(total_len)
        if encrypted:
            body_start = nonce_len + mac_len
            decrypted = self.security.decrypt_message(data[:nonce_len], data[body_start:], data[nonce_len:body_start])
        else:
            decrypted = data

        if not is_binary:
            return {"data": decrypted.decode('utf-8')}

        meta_length = struct.unpack('!I', decrypted[:4])[0]
        meta_end = 4 + meta_length
        return {"meta": json.loads(decrypted[4:meta_end]), "data": decrypted[meta_end:]}

    def recv_all(self, n: int) -> bytes:
        """Read exactly *n* bytes from the socket.

        Data is collected in bounded chunks and joined once, so memory grows
        with what was actually received rather than with the announced size.

        Raises:
            ConnectionError: the peer closed the connection before *n* bytes
                arrived.
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = self.conn.recv(min(remaining, self._RECV_CHUNK_SIZE))
            if not chunk:
                raise ConnectionError("Connection closed before receive all bytes")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)