This commit is contained in:
awin-x 2025-05-31 08:53:44 +08:00
parent 9c4a6af1bd
commit 1797c5c813
3 changed files with 70 additions and 26 deletions

View File

@ -38,25 +38,28 @@ class Connector:
return False return False
def request_service(self, service_name): def request_service(self, service_name):
punch_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
punch_socket.bind(('0.0.0.0', 0))
punch_port = punch_socket.getsockname()[1]
gathering_port = self.coordinator_port+2
print(f'punch socket send to {(self.coordinator_host, gathering_port)}')
punch_socket.sendto(json.dumps({'token': self.token}).encode(), (self.coordinator_host, gathering_port))
self._send_json({ self._send_json({
'action': 'request_service', 'action': 'request_service',
'service_name': service_name, 'service_name': service_name,
'token': self.token 'token': self.token
}) })
time.sleep(3)
response = self._recv_json() response = self._recv_json()
if response.get('status') == 'success': if response.get('status') == 'success':
provider_addr = tuple(response['provider_addr']) provider_addr = tuple(response['provider_addr'])
print(f"Connecting to provider at {provider_addr}") print(f"Connecting to provider at {provider_addr}")
# 使用UDP打洞
udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# 绑定到相同的本地端口用于后续TCP连接
udp_socket.bind(('0.0.0.0', 0))
punch_port = udp_socket.getsockname()[1]
# 向对方发送打洞包 # 向对方发送打洞包
for i in range(10): for i in range(20):
udp_socket.sendto(b'punch', provider_addr) punch_socket.sendto(b'pong pong pong pong', provider_addr)
time.sleep(0.2) time.sleep(0.2)
# Start listening for incoming connections from provider # Start listening for incoming connections from provider

View File

@ -13,6 +13,8 @@ class Coordinator:
# 初始化协调器服务端参数 # 初始化协调器服务端参数
self.host = host # 监听地址 self.host = host # 监听地址
self.port = port # 监听端口 self.port = port # 监听端口
self.gathering_port = port + 2
self.punch_addr = defaultdict()
# 生成盐值用于密码加密 # 生成盐值用于密码加密
self.salt = secrets.token_hex(8) self.salt = secrets.token_hex(8)
# 存储管理员密码哈希值(盐+密码) # 存储管理员密码哈希值(盐+密码)
@ -67,11 +69,11 @@ class Coordinator:
# 服务注册流程 # 服务注册流程
elif action == 'register_service': elif action == 'register_service':
client_token = data.get('token') connector_token = data.get('token')
if self.validate_token(client_token, addr[0]): if self.validate_token(connector_token, addr[0]):
services = data.get('services', []) services = data.get('services', [])
with self.lock: with self.lock:
self.services[client_token] = { self.services[connector_token] = {
'services': services, # 支持的服务列表 'services': services, # 支持的服务列表
'addr': addr, # 客户端地址信息 'addr': addr, # 客户端地址信息
'conn': conn # 客户端连接套接字 'conn': conn # 客户端连接套接字
@ -82,8 +84,8 @@ class Coordinator:
# 服务请求流程 # 服务请求流程
elif action == 'request_service': elif action == 'request_service':
client_token = data.get('token') connector_token = data.get('token')
if not self.validate_token(client_token, addr[0]): if not self.validate_token(connector_token, addr[0]):
self.send_json(conn, {'status': 'error', 'message': 'Invalid token'}) self.send_json(conn, {'status': 'error', 'message': 'Invalid token'})
continue continue
@ -94,25 +96,42 @@ class Coordinator:
provider_info = self.services[provider_token] provider_info = self.services[provider_token]
provider_addr = provider_info['addr'] provider_addr = provider_info['addr']
connector_addr = addr connector_addr = addr
count = 0
while connector_token not in self.punch_addr and count < 11:
time.sleep(0.3)
count += 1
if count >= 10:
print("Timeout waiting for connector to respond")
continue
# 通知服务提供方进行NAT打洞 # 通知服务提供方进行NAT打洞
punch_msg = { punch_msg = {
'action': 'punch_request', 'action': 'punch_request',
'connector_addr': connector_addr, # 请求方地址 'connector_addr': self.punch_addr[connector_token],
'service_name': service_name # 请求的服务名称 'service_name': service_name # 请求的服务名称
} }
self.punch_addr.pop(connector_token)
self.send_json(provider_info['conn'], punch_msg) self.send_json(provider_info['conn'], punch_msg)
count = 0
while provider_token not in self.punch_addr and count < 11:
time.sleep(0.3)
count += 1
if count >= 10:
print("Timeout waiting for provider to respond")
continue
# 响应请求方 # 响应请求方
self.send_json(conn, { self.send_json(conn, {
'status': 'success', 'status': 'success',
'provider_addr': provider_addr # 提供方地址信息 'provider_addr': self.punch_addr[provider_token]
}) })
self.punch_addr.pop(provider_token)
# 使用后立即销毁令牌 # 使用后立即销毁令牌
with self.lock: with self.lock:
if client_token in self.tokens: if connector_token in self.tokens:
del self.tokens[client_token] del self.tokens[connector_token]
else: else:
self.send_json(conn, {'status': 'error', 'message': 'Service not available'}) self.send_json(conn, {'status': 'error', 'message': 'Service not available'})
except (ConnectionResetError, json.JSONDecodeError): except (ConnectionResetError, json.JSONDecodeError):
@ -154,6 +173,23 @@ class Coordinator:
# 发送JSON数据 # 发送JSON数据
conn.sendall(json.dumps(data).encode()) conn.sendall(json.dumps(data).encode())
def punch_port_gathering(self):
gathering_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# 绑定到相同的本地端口用于后续TCP连接
gathering_socket.bind(('0.0.0.0', self.gathering_port))
print(f"Starting punch port gathering on {gathering_socket.getsockname()}")
while True:
data, addr = gathering_socket.recvfrom(4096)
print(f"Received punch port from {addr}")
try:
token = json.loads(data.decode()).get('token')
if self.validate_token(token, addr[0]):
self.punch_addr[token] = addr
except (ConnectionResetError, json.JSONDecodeError):
pass
def start(self): def start(self):
# 启动协调器服务 # 启动协调器服务
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -162,6 +198,11 @@ class Coordinator:
server.listen(5) server.listen(5)
print(f"Coordinator listening on {self.host}:{self.port}") print(f"Coordinator listening on {self.host}:{self.port}")
gathering_thread = threading.Thread(
target=self.punch_port_gathering,
daemon=True
)
gathering_thread.start()
while True: while True:
conn, addr = server.accept() conn, addr = server.accept()
# 为每个连接创建独立线程 # 为每个连接创建独立线程

View File

@ -56,22 +56,22 @@ class Provider:
def handle_punch_request(self, data): def handle_punch_request(self, data):
connector_addr = tuple(data['connector_addr']) connector_addr = tuple(data['connector_addr'])
udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
udp_socket.bind(('0.0.0.0', 0))
punch_port = udp_socket.getsockname()[1]
print(f"UDP punching using {punch_port}")
udp_socket.sendto(json.dumps({
'token': self.token
}).encode(), (self.coordinator_host, self.coordinator_port + 2))
service_name = data['service_name'] service_name = data['service_name']
print(f"Punching hole to connector at {connector_addr}, waiting 10 seconds...") print(f"Punching hole to connector at {connector_addr}, waiting 10 seconds...")
# Wait for 10 seconds to allow the connector to initiate its punch
time.sleep(2)
# 使用UDP打洞
udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# 绑定到相同的本地端口用于后续TCP连接
udp_socket.bind(('0.0.0.0', 0))
punch_port = udp_socket.getsockname()[1]
# 向对方发送打洞包 # 向对方发送打洞包
for i in range(10): for i in range(20):
udp_socket.sendto(b'punch', connector_addr) udp_socket.sendto(b'punch punch punch punch', connector_addr)
time.sleep(0.2) time.sleep(0.2)
time.sleep(10)
punch_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) punch_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
punch_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) punch_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)