diff --git a/generate.py b/generate.py index 7ff1859..412f932 100644 --- a/generate.py +++ b/generate.py @@ -1,4 +1,5 @@ import os +import copy import sys import time import getopt @@ -67,12 +68,17 @@ def get_pem_from_rsa_keypair(private_key, public_key): return pripem, pubpem -def get_rsa_keypair_from_pem(private_pem): +def get_rsa_keypair_from_private_pem(private_pem): private_key = serialization.load_pem_private_key(private_pem.encode(), password=None) public_key = private_key.public_key() return private_key, public_key +def get_rsa_pubkey_from_public_pem(public_pem): + public_key = serialization.load_pem_public_key(public_pem.encode()) + return public_key + + def rsa_sign_base64(private_key, bytes_data): signature = private_key.sign(bytes_data, padding.PSS( mgf=padding.MGF1(hashes.SHA256()), @@ -87,11 +93,11 @@ def rsa_encrypt_base64(public_key, bytes_data): mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None - ))) + ))).decode() -def rsa_decrypt_base64(private_key, data): - return private_key.decrypt(base64.b64decode(data), padding.OAEP( +def rsa_decrypt_base64(private_key, str_data): + return private_key.decrypt(base64.b64decode(str_data), padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None @@ -135,7 +141,6 @@ class Parser: self.local_private_key = None self.local_public_key = None self.local_autogen_nextport = 29100 - self.pending_peers = [] self.pending_accepts = [] self.tunnel_local_endpoint = {} self.tunnel_server_reports = {} @@ -167,23 +172,58 @@ class Parser: errprint('[ERROR] No registry client name found.') exit(1) - errprint('Resolving client {} from registry ({})...'.format(client_name, self.registry_domain)) + errprint('[REGISTRY] Resolving client {} from registry ({})...'.format(client_name, self.registry_domain)) try: res = requests.get('{}/query'.format(self.registry_domain), params={ "name": client_name, }) + res = res.json() + errprint('[REGISTRY-SERVER] {}'.format(res['message'])) - remote_result = res.json() - remote_peers = remote_result['peers'] - if self.registry_client_name not in remote_peers: - errprint('This client ({}) is not accepted by {}'.format(self.registry_client_name, client_name)) + if res['code'] < 0: return {} - remote_config = rsa_decrypt_base64(remote_peers[self.registry_client_name]) - return json.loads(remote_config) + remote_result = res['data'] + remote_peers = remote_result['peers'] + if self.registry_client_name not in remote_peers: + errprint('[REGISTRY-REMOTE] This client ({}) is not accepted by {}'.format(self.registry_client_name, client_name)) + return {} + + remote_config = rsa_decrypt_base64(self.local_private_key, remote_peers[self.registry_client_name]) + remote_config = json.loads(remote_config) + + remote_peers[self.registry_client_name] = remote_config + + return remote_result except Exception: errprint(traceback.format_exc()) - errprint('Exception happened during registry client resolve') + errprint('[REGISTRY] Exception happened during registry client resolve') + return {} + + def registry_query(self, client_name): + if not self.registry_domain: + errprint('[ERROR] Cannot query from registry, domain not specified') + exit(1) + if not self.registry_client_name: + errprint('[ERROR] No registry client name found.') + exit(1) + + errprint('[REGISTRY] Querying client {}...'.format(client_name)) + try: + res = requests.get('{}/query'.format(self.registry_domain), params={ + "name": client_name, + }) + res = res.json() + errprint('[REGISTRY-SERVER] {}'.format(res['message'])) + + if res['code'] < 0: + return {} + + remote_result = res['data'] + return remote_result + except Exception: + errprint(traceback.format_exc()) + errprint('[REGISTRY] Exception happened during registry client query') return {} def registry_upload(self, content): @@ -194,28 +234,28 @@ class Parser: errprint('[ERROR] No registry client name found.') exit(1) - errprint('Registering this client ({}) with registry ({})...'.format(self.registry_client_name, self.registry_domain)) + errprint('[REGISTRY] Registering this client ({}) with registry ({})...'.format(self.registry_client_name, self.registry_domain)) try: res = requests.post('{}/register'.format(self.registry_domain), json=content) res = res.json() - errprint('[REGISTRY] {}'.format(res['message'])) + errprint('[REGISTRY-SERVER] {}'.format(res['message'])) if res['code'] < 0: return False else: return True except Exception: errprint(traceback.format_exc()) - errprint('Exception happened during registry register') + errprint('[REGISTRY] Exception happened during registry register') return False - def registry_ensure(self): + def registry_ensure(self, peers=None): private_pem, public_pem = get_pem_from_rsa_keypair(None, self.local_public_key) can_ensure = self.registry_upload({ "name": self.registry_client_name, "pubkey": public_pem, "wgkey": self.wg_pubkey, - "peers": {}, + "peers": peers or {}, "sig": rsa_sign_base64(self.local_private_key, self.wg_pubkey.encode()) }) if not can_ensure: @@ -228,6 +268,20 @@ class Parser: "mode": mode, }) + def append_input_peer_clientside(self, peer_wgkey, allowed_ip, tunnel_name): + this_peer = [] + this_peer.append("PublicKey = {}".format(peer_wgkey)) + this_peer.append("AllowedIPs = {}".format(allowed_ip)) + this_peer.append("PersistentKeepalive = 5") + this_peer.append("#use-tunnel {}".format(tunnel_name)) + self.input_peer.append(this_peer) + + def append_input_peer_serverside(self, peer_wgkey, allowed_ip): + this_peer = [] + this_peer.append("PublicKey = {}".format(peer_wgkey)) + this_peer.append("AllowedIPs = {}".format(allowed_ip)) + self.input_peer.append(this_peer) + def add_muxer(self, listen_port, forward_start, forward_size): self.container_bootstrap.append({ "type": "mux", @@ -246,10 +300,11 @@ class Parser: "listen": int(listen_port), } - def add_gost_client_with(self, remote_config): + def add_gost_client_with(self, remote_config, remote_peer_config): self.local_autogen_nextport += 1 tunnel_name = "gen{}{}".format(self.wg_hash[:8], self.local_autogen_nextport) - self.add_gost_client(tunnel_name, self.local_autogen_nextport, "{}:{}".format(remote_config['ip'], remote_config['listen'])) + self.add_gost_client(tunnel_name, self.local_autogen_nextport, "{}:{}".format(remote_config['ip'], remote_peer_config['listen'])) + self.append_input_peer_clientside(remote_config["wgkey"], remote_peer_config["allowed"], tunnel_name) def add_gost_client_mux(self, tunnel_name, mux_size, listen_port, tunnel_remote): if self.podman_user: @@ -300,10 +355,11 @@ class Parser: self.get_podman_cmd_with("podman exec {} /root/bin/udp2raw_amd64 --conf-file {} | grep ^iptables".format(self.get_container_name(), ipt_filename_inside)) )) - def add_udp2raw_client_with(self, remote_config): + def add_udp2raw_client_with(self, remote_config, remote_peer_config): self.local_autogen_nextport += 1 tunnel_name = "gen{}{}".format(self.wg_hash[:8], self.local_autogen_nextport) - self.add_udp2raw_client(tunnel_name, self.local_autogen_nextport, remote_config["password"], "{}:{}".format(remote_config['ip'], remote_config['listen'])) + self.add_udp2raw_client(tunnel_name, self.local_autogen_nextport, remote_peer_config["password"], "{}:{}".format(remote_config['ip'], remote_peer_config['listen'])) + self.append_input_peer_clientside(remote_config["wgkey"], remote_peer_config["allowed"], tunnel_name) def add_udp2raw_client_mux(self, tunnel_name, mux_size, listen_port, tunnel_password, remote_addr): self.tunnel_local_endpoint[tunnel_name] = "127.0.0.1:{}".format(listen_port) @@ -363,11 +419,12 @@ class Parser: "sni": get_subject_name_from_cert(ssl_cert_path), } - def add_trojan_client_with(self, remote_config): + def add_trojan_client_with(self, remote_config, remote_peer_config): self.local_autogen_nextport += 1 tunnel_name = "gen{}{}".format(self.wg_hash[:8], self.local_autogen_nextport) - self.add_trojan_client(tunnel_name, self.local_autogen_nextport, remote_config["password"], - "{}:{}".format(remote_config["ip"], remote_config["listen"]), remote_config["target"], ssl_sni=remote_config["sni"]) + self.add_trojan_client(tunnel_name, self.local_autogen_nextport, remote_peer_config["password"], + "{}:{}".format(remote_config["ip"], remote_peer_config["listen"]), remote_peer_config["target"], ssl_sni=remote_peer_config["sni"]) + self.append_input_peer_clientside(remote_config["wgkey"], remote_peer_config["allowed"], tunnel_name) def add_trojan_client_mux(self, tunnel_name, mux_size, listen_port, tunnel_password, remote_addr, target_port, ssl_sni=None): if self.podman_user: @@ -411,7 +468,7 @@ class Parser: private_pem = parts[0] private_pem = base64.b64decode(private_pem).decode() - self.local_private_key, self.local_public_key = get_rsa_keypair_from_pem(private_pem) + self.local_private_key, self.local_public_key = get_rsa_keypair_from_private_pem(private_pem) errprint('Loaded 1 PEM private key') continue @@ -481,12 +538,14 @@ class Parser: client_name = parts[1] client_ip = parts[2] client_allowed = parts[3] + peer_allowed = parts[4] self.pending_accepts.append({ "tunnel": tunnel_name, "client": client_name, - "ip": client_ip, + "client_ip": client_ip, "allowed": client_allowed, + "peer_allowed": peer_allowed, }) self.flag_require_registry = True else: @@ -513,13 +572,18 @@ class Parser: # registry fetch connect-to for peer_client_name in unresolved_peers: - errprint('Resolving connect-to {}...'.format(peer_client_name)) - peer_config = self.registry_resolve(peer_client_name) + errprint('[REGISTRY-RESOLVE] Resolving connect-to {}...'.format(peer_client_name)) + peer_client_config = self.registry_resolve(peer_client_name) + if not peer_client_config: + errprint('[WARN] Unable to resolve client: {}'.format(peer_client_name)) + continue + + peer_config = peer_client_config["peers"][self.registry_client_name] { "udp2raw": self.add_udp2raw_client_with, "gost": self.add_gost_client_with, "trojan": self.add_trojan_client_with, - }.get(peer_config["type"], lambda x: x)(peer_config) + }.get(peer_config["type"], lambda x, y: False)(peer_client_config, peer_config) # compile interface for line in filted_input_interface: @@ -729,6 +793,29 @@ class Parser: )) self.result_postup.extend(self.result_container_postbootstrap) + + # registry fetch accept-client + if self.flag_require_registry and self.pending_accepts: + resolved_upload_arr = {} + + for accept_info in self.pending_accepts: + peer_client_name = accept_info["client"] + errprint('[REGISTRY-RESOLVE] Resolving accept-client {}...'.format(peer_client_name)) + peer_client_config = self.registry_query(peer_client_name) + if not peer_client_config: + errprint('[WARN] Unable to resolve client: {}'.format(peer_client_name)) + continue + + self.append_input_peer_serverside(peer_client_config["wgkey"], accept_info["allowed"]) + + peer_tunnel_info = copy.copy(self.tunnel_server_reports[accept_info['tunnel']]) + peer_tunnel_info["allowed"] = accept_info["peer_allowed"] + + public_key = get_rsa_pubkey_from_public_pem(peer_client_config["pubkey"]) + resolved_upload_arr[peer_client_name] = rsa_encrypt_base64(public_key, json.dumps(peer_tunnel_info, ensure_ascii=False).encode()) + + if resolved_upload_arr: + self.registry_ensure(peers=resolved_upload_arr) def compile_peers(self): if self.flag_is_route_forward and len(self.input_peer) > 1: @@ -797,7 +884,7 @@ class Parser: for ip_cidr in current_allowed: self.result_postup.append('PostUp=ip rule add from {} lookup {}'.format(ip_cidr, current_lookup)) self.result_postdown.append('PostDown=ip rule del from {} lookup {}'.format(ip_cidr, current_lookup)) - + def get_result(self): current_time = time.strftime("%Y-%m-%d %H:%M:%S") return '''# Generated by wg-ops at {}. DO NOT EDIT.