diff --git a/connector.py b/connector.py index 93f03fc..12da936 100644 --- a/connector.py +++ b/connector.py @@ -38,25 +38,28 @@ class Connector: return False def request_service(self, service_name): + punch_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + punch_socket.bind(('0.0.0.0', 0)) + punch_port = punch_socket.getsockname()[1] + gathering_port = self.coordinator_port+2 + print(f'punch socket send to {(self.coordinator_host, gathering_port)}') + punch_socket.sendto(json.dumps({'token': self.token}).encode(), (self.coordinator_host, gathering_port)) + self._send_json({ 'action': 'request_service', 'service_name': service_name, 'token': self.token }) + time.sleep(3) response = self._recv_json() if response.get('status') == 'success': provider_addr = tuple(response['provider_addr']) print(f"Connecting to provider at {provider_addr}") - # 使用UDP打洞 - udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # 绑定到相同的本地端口(用于后续TCP连接) - udp_socket.bind(('0.0.0.0', 0)) - punch_port = udp_socket.getsockname()[1] # 向对方发送打洞包 - for i in range(10): - udp_socket.sendto(b'punch', provider_addr) + for i in range(20): + punch_socket.sendto(b'pong pong pong pong', provider_addr) time.sleep(0.2) # Start listening for incoming connections from provider diff --git a/coordinator.py b/coordinator.py index 391ecff..bedeac9 100644 --- a/coordinator.py +++ b/coordinator.py @@ -13,6 +13,8 @@ class Coordinator: # 初始化协调器服务端参数 self.host = host # 监听地址 self.port = port # 监听端口 + self.gathering_port = port + 2 + self.punch_addr = defaultdict() # 生成盐值用于密码加密 self.salt = secrets.token_hex(8) # 存储管理员密码哈希值(盐+密码) @@ -67,11 +69,11 @@ class Coordinator: # 服务注册流程 elif action == 'register_service': - client_token = data.get('token') - if self.validate_token(client_token, addr[0]): + connector_token = data.get('token') + if self.validate_token(connector_token, addr[0]): services = data.get('services', []) with self.lock: - self.services[client_token] = { + self.services[connector_token] = { 'services': services, # 支持的服务列表 'addr': addr, # 客户端地址信息 'conn': conn # 客户端连接套接字 @@ -82,8 +84,8 @@ class Coordinator: # 服务请求流程 elif action == 'request_service': - client_token = data.get('token') - if not self.validate_token(client_token, addr[0]): + connector_token = data.get('token') + if not self.validate_token(connector_token, addr[0]): self.send_json(conn, {'status': 'error', 'message': 'Invalid token'}) continue @@ -94,25 +96,42 @@ class Coordinator: provider_info = self.services[provider_token] provider_addr = provider_info['addr'] connector_addr = addr + count = 0 + while connector_token not in self.punch_addr and count < 11: + time.sleep(0.3) + count += 1 + if count >= 10: + print("Timeout waiting for connector to respond") + continue # 通知服务提供方进行NAT打洞 punch_msg = { 'action': 'punch_request', - 'connector_addr': connector_addr, # 请求方地址 + 'connector_addr': self.punch_addr[connector_token], 'service_name': service_name # 请求的服务名称 } + self.punch_addr.pop(connector_token) self.send_json(provider_info['conn'], punch_msg) + count = 0 + while provider_token not in self.punch_addr and count < 11: + time.sleep(0.3) + count += 1 + if count >= 10: + print("Timeout waiting for provider to respond") + continue + # 响应请求方 self.send_json(conn, { 'status': 'success', - 'provider_addr': provider_addr # 提供方地址信息 + 'provider_addr': self.punch_addr[provider_token] }) + self.punch_addr.pop(provider_token) # 使用后立即销毁令牌 with self.lock: - if client_token in self.tokens: - del self.tokens[client_token] + if connector_token in self.tokens: + del self.tokens[connector_token] else: self.send_json(conn, {'status': 'error', 'message': 'Service not available'}) except (ConnectionResetError, json.JSONDecodeError): @@ -154,6 +173,23 @@ class Coordinator: # 发送JSON数据 conn.sendall(json.dumps(data).encode()) + def punch_port_gathering(self): + gathering_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # 绑定到相同的本地端口(用于后续TCP连接) + gathering_socket.bind(('0.0.0.0', self.gathering_port)) + print(f"Starting punch port gathering on {gathering_socket.getsockname()}") + while True: + data, addr = gathering_socket.recvfrom(4096) + print(f"Received punch port from {addr}") + try: + token = json.loads(data.decode()).get('token') + if self.validate_token(token, addr[0]): + self.punch_addr[token] = addr + except (ConnectionResetError, json.JSONDecodeError): + pass + + + def start(self): # 启动协调器服务 server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -162,6 +198,11 @@ class Coordinator: server.listen(5) print(f"Coordinator listening on {self.host}:{self.port}") + gathering_thread = threading.Thread( + target=self.punch_port_gathering, + daemon=True + ) + gathering_thread.start() while True: conn, addr = server.accept() # 为每个连接创建独立线程 diff --git a/provider.py b/provider.py index 874959f..85014be 100644 --- a/provider.py +++ b/provider.py @@ -56,22 +56,22 @@ class Provider: def handle_punch_request(self, data): connector_addr = tuple(data['connector_addr']) + udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + udp_socket.bind(('0.0.0.0', 0)) + punch_port = udp_socket.getsockname()[1] + print(f"UDP punching using {punch_port}") + udp_socket.sendto(json.dumps({ + 'token': self.token + }).encode(), (self.coordinator_host, self.coordinator_port + 2)) service_name = data['service_name'] print(f"Punching hole to connector at {connector_addr}, waiting 10 seconds...") - # Wait for 10 seconds to allow the connector to initiate its punch - time.sleep(2) - - # 使用UDP打洞 - udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # 绑定到相同的本地端口(用于后续TCP连接) - udp_socket.bind(('0.0.0.0', 0)) - punch_port = udp_socket.getsockname()[1] # 向对方发送打洞包 - for i in range(10): - udp_socket.sendto(b'punch', connector_addr) + for i in range(20): + udp_socket.sendto(b'punch punch punch punch', connector_addr) time.sleep(0.2) + time.sleep(10) punch_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) punch_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)