Added libs
This commit is contained in:
264
fspn/protocol/connection.py
Normal file
264
fspn/protocol/connection.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from ..utils.observable import Observable
|
||||
from ..utils.wrapper_util import threaded
|
||||
from .security import Security
|
||||
|
||||
from enum import Enum
|
||||
import socket
|
||||
import ipaddress
|
||||
import struct
|
||||
import time, datetime
|
||||
import logging, traceback
|
||||
import json
|
||||
import uuid
|
||||
|
||||
# TODO Impossible: Hiding ip in a p2p connection hahahahahaha
|
||||
|
||||
HEADER_STRUCTURE = '!I?IIb' # size, encrypted, nonce_len, mac_len, is_binary
|
||||
HEADER_SIZE = struct.calcsize(HEADER_STRUCTURE) # 14 bytes
|
||||
|
||||
MAX_CONNECTION_TRIES = 3
|
||||
|
||||
class EVENTS(Enum):
|
||||
ON_CONNECTION = 0
|
||||
ON_CONNECTION_ERROR = 1
|
||||
ON_DISCONNECTION = 2
|
||||
ON_MESSAGE = 3
|
||||
|
||||
events = [EVENTS.ON_CONNECTION, EVENTS.ON_CONNECTION_ERROR, EVENTS.ON_DISCONNECTION, EVENTS.ON_MESSAGE]
|
||||
|
||||
class STATUS(Enum):
|
||||
DISCONNECTED = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
HANDSHAKING = 3
|
||||
ERROR = -1
|
||||
|
||||
class Connection(Observable):
|
||||
def __init__(self, user, pmc, conn=None):
|
||||
super().__init__(events)
|
||||
self.security = Security(user, pmc)
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.address = None
|
||||
self.hostname = None
|
||||
self.bind_address = None
|
||||
self.conn = conn
|
||||
if conn:
|
||||
self.set_addresses()
|
||||
self.id = str(uuid.uuid4())
|
||||
self.handshake_payload = None
|
||||
|
||||
def set_addresses(self, address=None):
|
||||
if address:
|
||||
self.address = address
|
||||
host, port = address
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
self.hostname = host
|
||||
else:
|
||||
self.address = self.conn.getpeername()
|
||||
self.bind_address = self.conn.getsockname()
|
||||
|
||||
|
||||
@threaded
|
||||
def connect(self, address, bind_address=('0.0.0.0', 0)):
|
||||
self.set_addresses(address)
|
||||
self.bind_address = bind_address
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
#s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
s.bind(bind_address)
|
||||
s.settimeout(10)
|
||||
self.conn = s
|
||||
except Exception:
|
||||
self.status = STATUS.ERROR
|
||||
self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Error to setup connection")
|
||||
raise
|
||||
|
||||
self.status = STATUS.CONNECTING
|
||||
logging.info(f'Socket trying to connect: {self.bind_address} -> {address}')
|
||||
|
||||
for i in range(MAX_CONNECTION_TRIES):
|
||||
try:
|
||||
s.settimeout(None)
|
||||
self.conn = s
|
||||
self.bind_address = s.getsockname()
|
||||
s.connect(address)
|
||||
self.address = s.getpeername()
|
||||
break
|
||||
except Exception as e:
|
||||
logging.exception("ERROR")
|
||||
if i < MAX_CONNECTION_TRIES - 1:
|
||||
continue
|
||||
else:
|
||||
self.status = STATUS.ERROR
|
||||
self.fire_event(EVENTS.ON_CONNECTION_ERROR, error=f"No connection could be made in {MAX_CONNECTION_TRIES} retries: {e}")
|
||||
raise
|
||||
|
||||
self.new_connection()
|
||||
|
||||
def new_connection(self):
|
||||
self.handshake_create_payload()
|
||||
|
||||
def handshake_create_payload(self):
|
||||
self.status = STATUS.HANDSHAKING
|
||||
logging.info(f'Socket handshaking: {self.bind_address} -> {self.address}')
|
||||
my_ecdsa_str = self.security.user
|
||||
my_proof_of_work = self.security.proof_of_work
|
||||
my_ecdh_pk = self.security.ecdh.public_key_to_str()
|
||||
date = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.handshake_payload = {'ecdsa':my_ecdsa_str, 'pof':my_proof_of_work, 'ecdh':my_ecdh_pk, 'date':date}
|
||||
payload = {'ecdsa':my_ecdsa_str, 'pof':my_proof_of_work, 'ecdh':my_ecdh_pk, 'date':date}
|
||||
request_id = self.security.sign_message_ecdsa(json.dumps(payload), self.payload_signature_callback, "Handshake Connection")
|
||||
self.wait_signature()
|
||||
|
||||
@threaded
|
||||
def wait_signature(self):
|
||||
# TODO Config this time as security time validation
|
||||
time.sleep(120)
|
||||
if self.status == STATUS.HANDSHAKING:
|
||||
logging.info(f"Closing connection, payload signature time exceded")
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Payload signature time exceded")
|
||||
self.close_connection()
|
||||
|
||||
def payload_signature_callback(self, request_id, signature):
|
||||
if self.status != STATUS.HANDSHAKING:
|
||||
logging.info(f"Not using signature, connection is closed")
|
||||
return None
|
||||
if signature:
|
||||
self.handshake_payload['signature'] = signature
|
||||
self.send_message(json.dumps(self.handshake_payload), False)
|
||||
|
||||
while self.status == STATUS.HANDSHAKING:
|
||||
message = self.message_reader()["data"]
|
||||
logging.error
|
||||
if(message):
|
||||
payload = json.loads(message)
|
||||
payload_signed = payload.copy()
|
||||
payload_signed.pop('signature')
|
||||
self.security.handshake_validation(payload['ecdsa'], json.dumps(payload_signed), payload['signature'], payload['ecdh'], payload['pof'], payload['date'])
|
||||
|
||||
self.status = STATUS.CONNECTED
|
||||
logging.info(f'Ready: {self.bind_address} -> {self.address}')
|
||||
self.fire_event(EVENTS.ON_CONNECTION)
|
||||
self.wait_message()
|
||||
else:
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.fire_event(EVENTS.ON_CONNECTION_ERROR, error="Payload signature is None")
|
||||
self.close_connection()
|
||||
|
||||
@threaded
|
||||
def wait_message(self):
|
||||
while self.status == STATUS.CONNECTED:
|
||||
try:
|
||||
message = self.message_reader()
|
||||
except socket.timeout as to:
|
||||
logging.exception("ERROR")
|
||||
continue
|
||||
except (ConnectionAbortedError, EOFError, ConnectionResetError, OSError):
|
||||
# TODO Maybe try to reconnect
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.fire_event(EVENTS.ON_DISCONNECTION)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.exception("ERROR")
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.fire_event(EVENTS.ON_DISCONNECTION)
|
||||
break
|
||||
|
||||
try:
|
||||
if(message):
|
||||
# logging.debug(message)
|
||||
self.fire_event(EVENTS.ON_MESSAGE, message=message)
|
||||
except Exception as e:
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
def close_connection(self):
|
||||
logging.info(f'Socket closed: {self.address}')
|
||||
self.conn.close()
|
||||
self.status = STATUS.DISCONNECTED
|
||||
self.fire_event(EVENTS.ON_DISCONNECTION)
|
||||
|
||||
@threaded
|
||||
def send_binary(self, data: bytes, meta: dict = None, encrypted=True):
|
||||
meta_json = json.dumps(meta or {})
|
||||
meta_encoded = meta_json.encode('utf-8')
|
||||
meta_len_bytes = struct.pack('!I', len(meta_encoded))
|
||||
full_payload = meta_len_bytes + meta_encoded + data
|
||||
|
||||
nonce = b''
|
||||
mac = b''
|
||||
|
||||
if encrypted:
|
||||
nonce, full_payload, mac = self.security.encrypt_message(full_payload)
|
||||
|
||||
message_header = struct.pack(
|
||||
HEADER_STRUCTURE,
|
||||
len(nonce) + len(mac) + len(full_payload),
|
||||
encrypted,
|
||||
len(nonce),
|
||||
len(mac),
|
||||
1 # is_binary = True
|
||||
)
|
||||
|
||||
self.conn.sendall(message_header + nonce + mac + full_payload)
|
||||
|
||||
@threaded
|
||||
def send_message(self, message: str, encrypted=True):
|
||||
nonce = b''
|
||||
mac = b''
|
||||
meta = b''
|
||||
encoded_msg = message.encode('utf-8')
|
||||
|
||||
if encrypted:
|
||||
nonce, encoded_msg, mac = self.security.encrypt_message(encoded_msg)
|
||||
|
||||
message_header = struct.pack(
|
||||
HEADER_STRUCTURE,
|
||||
len(nonce) + len(mac) + len(meta) + len(encoded_msg),
|
||||
encrypted,
|
||||
len(nonce),
|
||||
len(mac),
|
||||
0 # is_binary = False
|
||||
)
|
||||
|
||||
self.conn.sendall(message_header + nonce + mac + encoded_msg)
|
||||
|
||||
|
||||
def message_reader(self):
|
||||
message_header = self.recv_all(HEADER_SIZE)
|
||||
if not message_header:
|
||||
self.close_connection()
|
||||
return None
|
||||
|
||||
total_len, encrypted, nonce_len, mac_len, is_binary = struct.unpack(HEADER_STRUCTURE, message_header)
|
||||
data = self.recv_all(total_len)
|
||||
|
||||
if encrypted:
|
||||
nonce = data[:nonce_len]
|
||||
mac = data[nonce_len:nonce_len + mac_len]
|
||||
payload = data[nonce_len + mac_len:]
|
||||
decrypted = self.security.decrypt_message(nonce, payload, mac)
|
||||
else:
|
||||
decrypted = data
|
||||
|
||||
if is_binary:
|
||||
meta_length = struct.unpack('!I', decrypted[:4])[0]
|
||||
meta_raw = decrypted[4:4 + meta_length]
|
||||
meta = json.loads(meta_raw)
|
||||
file_data = decrypted[4 + meta_length:]
|
||||
return {"meta": meta, "data": file_data}
|
||||
else:
|
||||
return {"data": decrypted.decode('utf-8')}
|
||||
|
||||
|
||||
def recv_all(self, n: int) -> bytes:
|
||||
buffer = b''
|
||||
while len(buffer) < n:
|
||||
chunk = self.conn.recv(n - len(buffer))
|
||||
if not chunk:
|
||||
raise ConnectionError("Connection closed before receive all bytes")
|
||||
buffer += chunk
|
||||
return buffer
|
||||
112
fspn/protocol/security.py
Normal file
112
fspn/protocol/security.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from ..utils.observable import Observable
|
||||
from ..utils.wrapper_util import singleton
|
||||
from ..utils import sha256_util, aes_util, ecdh_util, ecdsa_util
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
# class EcdsaKey:
|
||||
# def __init__(self) -> None:
|
||||
# self.verifying:ecdsa_util.VerifyingKey = None
|
||||
# self.signing:ecdsa_util.SigningKey = None
|
||||
|
||||
# def create_key_from_string(self, password:str):
|
||||
# self.verifying, self.signing = ecdsa_util.create_keys(password.encode())
|
||||
|
||||
# def create_key_from_bytes(self, password:bytes):
|
||||
# self.verifying, self.signing = ecdsa_util.create_keys(password)
|
||||
|
||||
# def load_verifying(self, key:str):
|
||||
# self.verifying = ecdsa_util.load_verifying_key(base64.b64decode(key.encode()))
|
||||
|
||||
# def verifying_key_to_str(self):
|
||||
# return base64.b64encode(self.verifying.to_string('compressed')).decode()
|
||||
|
||||
class UserData:
|
||||
def __init__(self):
|
||||
self.proof_of_work = None
|
||||
|
||||
# Password Manager Client
|
||||
class Pmc:
|
||||
def raiseException(self):
|
||||
raise Exception("Missing Password Manager Client")
|
||||
|
||||
def get(self, user) -> UserData:
|
||||
self.raiseException()
|
||||
|
||||
# Returns a request_id. Callback receives str:request_id str:signature
|
||||
def sign(self, data, user, callback, info=None) -> str:
|
||||
self.raiseException()
|
||||
|
||||
class EcdhKey:
|
||||
def __init__(self):
|
||||
self.public, self.private = ecdh_util.generate_keys()
|
||||
self.derived_key = None
|
||||
|
||||
def generate_derived_key(self, peer_key:str):
|
||||
ecdh_pk = ecdh_util.load_public_key_str(peer_key, True)
|
||||
shared_key = ecdh_util.generate_shared_key(self.private, ecdh_pk)
|
||||
self.derived_key = ecdh_util.generate_derived_key(shared_key)
|
||||
|
||||
def update_derived_key(self):
|
||||
self.derived_key = ecdh_util.generate_derived_key(self.derived_key)
|
||||
|
||||
def public_key_to_str(self):
|
||||
return ecdh_util.public_key_to_str(self.public, True)
|
||||
|
||||
class Security():
|
||||
def __init__(self, user, pmc:Pmc):
|
||||
self.pmc = pmc
|
||||
self.user = user
|
||||
self.proof_of_work = self.pmc.get(user).proof_of_work
|
||||
|
||||
self.peer_user = None
|
||||
self.peer_ecdsa = None
|
||||
self.ecdh = EcdhKey()
|
||||
|
||||
self.peer_pof_level = None
|
||||
self.min_proof_of_work_level = 4
|
||||
|
||||
def encrypt_message(self, message:bytes):
|
||||
return aes_util.encrypt(message, self.ecdh.derived_key)
|
||||
|
||||
def decrypt_message(self, nonce:bytes, message:bytes, mac:bytes):
|
||||
return aes_util.decrypt_and_verify(nonce, message, mac, self.ecdh.derived_key)
|
||||
|
||||
def sign_message_ecdsa(self, message:str, callback, info=None):
|
||||
hash_message = sha256_util.hash_string(message)
|
||||
return self.pmc.sign(hash_message, self.user, callback, info)
|
||||
# return base64.b64encode(ecdsa_util.sign_message(message, self.my_ecdsa.signing)).decode()
|
||||
|
||||
def check_signature_ecdsa(self, message:str, signature:str):
|
||||
hash_message = sha256_util.hash_string(message)
|
||||
return ecdsa_util.verify_message(hash_message.encode(), base64.b64decode(signature.encode()), self.peer_ecdsa)
|
||||
|
||||
def check_signature_ecdsa_vk(self, verifying_key:str, message:str, signature:str):
|
||||
hash_message = sha256_util.hash_string(message)
|
||||
return ecdsa_util.verify_message(hash_message.encode(), base64.b64decode(signature.encode()), ecdsa_util.load_verifying_key(base64.b64decode(verifying_key.encode())))
|
||||
|
||||
# def encrypt_message_ecdsa(self, message:str):
|
||||
# encrypted = ecdsa_util.encrypt_message(base64.b64decode(self.peer_ecdsa.verifying_key_to_str().encode()), message.encode())
|
||||
# return base64.b64encode(encrypted).decode()
|
||||
|
||||
# def decrypt_message_ecdsa(self, message:str):
|
||||
# return ecdsa_util.decrypt_message(self.my_ecdsa.signing.to_string(), base64.b64decode(message.encode()))
|
||||
|
||||
|
||||
def verify_proof_of_work(self, ecdsa:str, proof_of_work:str):
|
||||
hash = sha256_util.hash_string(f'{ecdsa}{proof_of_work}')
|
||||
logging.debug(f'ECDSA {ecdsa}, POF {proof_of_work}, HASH {hash}')
|
||||
if hash[0:self.min_proof_of_work_level] != "0"*self.min_proof_of_work_level:
|
||||
raise Exception(f'Proof of work below minimum level {self.min_proof_of_work_level}')
|
||||
|
||||
# TODO validate date/time
|
||||
def handshake_validation(self, peer_ecdsa, payload, payload_signature, ecdh, proof_of_work, date):
|
||||
self.peer_user = peer_ecdsa
|
||||
self.peer_ecdsa = ecdsa_util.load_verifying_key(base64.b64decode(peer_ecdsa.encode()))
|
||||
self.check_signature_ecdsa(payload, payload_signature)
|
||||
|
||||
self.verify_proof_of_work(peer_ecdsa, proof_of_work)
|
||||
self.ecdh.generate_derived_key(ecdh)
|
||||
|
||||
74
fspn/protocol/server.py
Normal file
74
fspn/protocol/server.py
Normal file
@@ -0,0 +1,74 @@
|
||||
|
||||
from ..utils.observable import Observable, Event as ObservableEvent
|
||||
from ..utils.wrapper_util import threaded
|
||||
from .connection import Connection, EVENTS as CONNECTION_EVENTS
|
||||
|
||||
from enum import Enum
|
||||
import logging, traceback
|
||||
import socket
|
||||
import random
|
||||
|
||||
class EVENTS(Enum):
|
||||
ON_START = 0
|
||||
ON_START_ERROR = 1
|
||||
ON_CONNECTION = 2
|
||||
ON_CONNECTION_ERROR = 3
|
||||
ON_DISCONNECTION = 4
|
||||
ON_MESSAGE = 5
|
||||
|
||||
class Server(Observable):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.connections:dict[tuple[str,int],Connection] = {}
|
||||
self.bind_address = None
|
||||
self.running = False
|
||||
self.user = None
|
||||
|
||||
@threaded
|
||||
def run(self, user, pmc, bind_address = ('127.0.0.1', random.randint(5000, 5999))):
|
||||
try:
|
||||
self.user = user
|
||||
if not bind_address:
|
||||
self.bind_address = ('127.0.0.1', random.randint(5000, 5999))
|
||||
else:
|
||||
self.bind_address = bind_address
|
||||
logging.info(f"Starting server on address {self.bind_address}")
|
||||
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind(self.bind_address)
|
||||
s.settimeout(10)
|
||||
s.listen(5)
|
||||
self.running = True
|
||||
self.fire_event(EVENTS.ON_START)
|
||||
logging.info(f"Listening on {self.bind_address}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
conn, addr = s.accept()
|
||||
logging.info(f"Incoming connection: {addr}")
|
||||
connection = Connection(user, pmc, conn)
|
||||
connection.subscribe_event(CONNECTION_EVENTS.ON_CONNECTION, self.on_server_connection)
|
||||
connection.subscribe_event(CONNECTION_EVENTS.ON_MESSAGE, self.on_server_message)
|
||||
connection.subscribe_event(CONNECTION_EVENTS.ON_DISCONNECTION, self.on_server_disconnection)
|
||||
self.connections[addr] = connection
|
||||
self.connections[addr].new_connection()
|
||||
except socket.timeout:
|
||||
continue
|
||||
except Exception:
|
||||
logging.error("ERROR")
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logging.error("ERROR")
|
||||
self.fire_event(EVENTS.ON_START_ERROR, error=e)
|
||||
|
||||
|
||||
def on_server_connection(self, event:ObservableEvent):
|
||||
pass
|
||||
|
||||
def on_server_message(self, event):
|
||||
pass
|
||||
|
||||
def on_server_disconnection(self, event):
|
||||
pass
|
||||
Reference in New Issue
Block a user