# Listing metadata: 411 lines, 15 KiB, Python.
# connector.py
|
||
import socket
|
||
import json
|
||
import threading
|
||
import time
|
||
import uuid
|
||
import heapq
|
||
from collections import deque
|
||
|
||
|
||
class ReliableChannel:
    """Reliable, ordered byte delivery over a shared UDP socket.

    Similar to the class in the provider. Outgoing data is split into
    sequence-numbered packets, at most ``window_size`` of which are in flight
    at once; packets not acknowledged within ``retransmit_timeout`` seconds
    are resent by :meth:`check_retransmit`. Incoming packets are reordered
    through ``recv_buffer`` and acknowledged cumulatively with
    ``ack_seq = expected_seq - 1``.

    Each direction has its own sequence space: ``next_send_seq`` numbers
    outgoing packets and ``expected_seq`` tracks the next in-order incoming
    packet. NOTE(review): the provider-side counterpart must likewise number
    its outgoing packets independently, starting at 0 (an earlier version
    shared one counter for both directions, which deadlocks as soon as both
    sides send concurrently).

    The methods are called from several threads (local TCP reader, UDP
    listener, channel monitor), so shared state is guarded by ``_lock``.
    """

    def __init__(self, udp_sock, remote_addr, conn_id, max_chunk=8192):
        """Create a channel to ``remote_addr`` identified by ``conn_id``.

        Args:
            udp_sock: socket used for every transmission (only ``sendto``
                is called on it).
            remote_addr: peer address passed to ``sendto``.
            conn_id: identifier stamped on every datagram of this channel.
            max_chunk: maximum raw bytes per packet. Payloads are
                hex-encoded (doubling their size) inside JSON, so this must
                keep each datagram below the 65,507-byte IPv4 UDP limit.

        Raises:
            ValueError: if ``max_chunk`` is not positive.
        """
        if max_chunk <= 0:
            raise ValueError("max_chunk must be positive")
        self.udp_sock = udp_sock
        self.remote_addr = remote_addr
        self.conn_id = conn_id
        self.send_buffer = deque()    # unacknowledged outgoing packets, ascending seq
        self.recv_buffer = {}         # out-of-order incoming payloads keyed by seq
        self.recv_queue = deque()     # retained for get_received_data() callers
        self.next_send_seq = 0        # seq assigned to the next outgoing packet
        self.expected_seq = 0         # next in-order incoming seq
        self.last_ack = -1            # highest outgoing seq the peer has acknowledged
        self.ack_interval = 0.2
        self.last_ack_time = 0
        self.window_size = 10
        self.retransmit_timers = {}   # seq -> last transmit time, in-flight packets only
        self.retransmit_timeout = 1
        self.max_chunk = max_chunk
        self.running = True
        # Re-entrant because process_packet() calls send_ack() and the
        # public methods call the private helpers while holding it.
        self._lock = threading.RLock()

    def send(self, data, is_control=False):
        """Queue ``data`` for reliable delivery and transmit what the window allows.

        ``data`` is split into packets of at most ``max_chunk`` bytes; an
        empty payload still produces one (empty) packet. Does nothing once
        the channel is closed.
        """
        if not self.running:
            return
        with self._lock:
            # max(..., 1) keeps the historical behaviour of sending one packet for b''.
            for start in range(0, max(len(data), 1), self.max_chunk):
                chunk = data[start:start + self.max_chunk]
                self.send_buffer.append({
                    'seq': self.next_send_seq,
                    'data': chunk.hex(),
                    'is_control': is_control,
                    'timestamp': time.time()
                })
                self.next_send_seq += 1
            self._send_from_buffer()

    def _transmit(self, packet):
        """Send one buffered packet to the peer as a JSON ``data`` datagram."""
        self.udp_sock.sendto(json.dumps({
            'action': 'data',
            'packet': packet,
            'conn_id': self.conn_id
        }).encode(), self.remote_addr)

    def _send_from_buffer(self):
        """Drop acknowledged packets, then transmit window packets not yet in flight.

        Caller must hold ``_lock``. Packets already in flight are left to
        :meth:`check_retransmit` instead of being resent on every call.
        """
        while self.send_buffer and self.send_buffer[0]['seq'] <= self.last_ack:
            self.send_buffer.popleft()

        now = time.time()
        for index, packet in enumerate(self.send_buffer):
            if index >= self.window_size:
                break
            if packet['seq'] in self.retransmit_timers:
                continue
            self._transmit(packet)
            self.retransmit_timers[packet['seq']] = now

    def process_ack(self, ack_seq):
        """Apply a cumulative ACK from the peer and slide the send window."""
        print(f"收到ack: {ack_seq}")
        with self._lock:
            # Never treat packets we have not sent yet as acknowledged.
            ack_seq = min(ack_seq, self.next_send_seq - 1)
            if ack_seq > self.last_ack:
                print(f"更新last_ack: {ack_seq}")
                self.last_ack = ack_seq
                for seq in [s for s in self.retransmit_timers if s <= ack_seq]:
                    del self.retransmit_timers[seq]
                self._send_from_buffer()

    def process_packet(self, packet):
        """Handle an incoming data packet and acknowledge it.

        Returns:
            The bytes now deliverable in order: this packet's payload plus
            any contiguous packets that were buffered behind it, or ``None``
            if the packet was out of order or a duplicate.
        """
        seq = packet['seq']
        data = bytes.fromhex(packet['data'])

        with self._lock:
            if seq == self.expected_seq:
                self.expected_seq += 1
                delivered = data + self._process_buffered()
            else:
                if seq > self.expected_seq:
                    self.recv_buffer[seq] = data
                delivered = None
            # Acknowledge after advancing expected_seq so the ACK covers this
            # packet; duplicates are re-acked so the peer stops resending.
            self.send_ack()
            self.last_ack_time = time.time()
        return delivered

    def _process_buffered(self):
        """Pop buffered packets contiguous with ``expected_seq``; return their bytes joined.

        Caller must hold ``_lock``.
        """
        chunks = []
        while self.expected_seq in self.recv_buffer:
            chunks.append(self.recv_buffer.pop(self.expected_seq))
            self.expected_seq += 1
        return b''.join(chunks)

    def get_received_data(self):
        """Pop the oldest entry from ``recv_queue``, or return ``None`` if it is empty."""
        return self.recv_queue.popleft() if self.recv_queue else None

    def send_ack(self):
        """Send a cumulative ACK for everything received in order so far."""
        ack_packet = {
            'action': 'ack',
            'ack_seq': self.expected_seq - 1,
            'window': self.window_size,
            'conn_id': self.conn_id
        }
        self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)

    def check_retransmit(self):
        """Resend every in-flight packet whose last transmission is older than ``retransmit_timeout``."""
        now = time.time()
        with self._lock:
            if not self.send_buffer:
                return
            # Buffered seqs are contiguous, so a seq maps directly to an index.
            base_seq = self.send_buffer[0]['seq']
            for seq, sent_at in list(self.retransmit_timers.items()):
                if now - sent_at <= self.retransmit_timeout:
                    continue
                index = seq - base_seq
                if not 0 <= index < len(self.send_buffer):
                    self.retransmit_timers.pop(seq, None)  # stale timer
                    continue
                print(f'有重传{seq}')
                self._transmit(self.send_buffer[index])
                self.retransmit_timers[seq] = now

    def close(self):
        """Stop the channel and discard all buffered state."""
        with self._lock:
            self.running = False
            self.send_buffer.clear()
            self.recv_buffer.clear()
            self.retransmit_timers.clear()


class ServiceConnector:
    """Client side of the tunnel: exposes a remote service on a local TCP port.

    ``run()`` asks the coordinator which provider serves ``service_name``,
    UDP-hole-punches to that provider, then accepts local TCP clients on
    127.0.0.1:``local_port``. Each client is relayed over its own
    ``ReliableChannel``, keyed by a per-connection ``conn_id``.
    """

    # Seconds forward_data() waits for the provider's 'connected' reply.
    connect_timeout = 10.0

    def __init__(self, coordinator_addr, service_name, local_port):
        self.coordinator_addr = coordinator_addr
        self.service_name = service_name
        self.local_port = local_port
        self.client_id = f"connector-{uuid.uuid4().hex[:8]}"

        # One UDP socket carries coordinator traffic, hole punching and all channels.
        self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.udp_sock.bind(('0.0.0.0', 0))
        self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)  # 1 MB receive buffer
        self.udp_sock.settimeout(5)

        # Local TCP listener that applications connect to.
        self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.tcp_sock.bind(('127.0.0.1', local_port))
        self.tcp_sock.listen(10)
        print(f"本地端口映射: 127.0.0.1:{local_port} -> 远程服务 '{service_name}'")

        # Connection state
        self.active_connections = {}  # conn_id -> local TCP client socket
        self.reliable_channels = {}   # conn_id -> ReliableChannel
        self.provider_addr = None
        self.internal_port = None
        self.target_id = None
        self.running = True
        self._shut_down = False

    def request_service(self):
        """Ask the coordinator for a provider of ``service_name``.

        On success stores ``provider_addr``, ``internal_port`` and
        ``target_id`` and returns True. Returns False if the request is
        rejected or the coordinator does not answer within the socket timeout.
        """
        message = {
            'action': 'request',
            'service_name': self.service_name,
            'client_id': self.client_id
        }
        self.udp_sock.sendto(json.dumps(message).encode(), self.coordinator_addr)

        try:
            data, _ = self.udp_sock.recvfrom(4096)
            response = json.loads(data.decode())
        except (socket.timeout, ValueError):
            print("服务请求失败: 协调服务器无有效响应")
            return False

        if response.get('status') == 'success':
            self.provider_addr = tuple(response['provider_addr'])
            self.internal_port = response['internal_port']
            self.target_id = response['target_id']
            print(f"找到服务提供者: {self.provider_addr}, 端口: {self.internal_port}")
            return True
        print(f"服务请求失败: {response.get('message')}")
        return False

    def punch_hole(self):
        """Perform UDP hole punching towards the provider.

        Returns True once any datagram from ``provider_addr`` arrives within
        5 s of the connectivity check; datagrams from other addresses are
        reported and ignored. Leaves the UDP socket with a 0.1 s timeout.
        """
        message = {
            'action': 'punch_request',
            'client_id': self.client_id,
            'target_id': self.target_id
        }
        self.udp_sock.sendto(json.dumps(message).encode(), self.coordinator_addr)

        try:
            data, _ = self.udp_sock.recvfrom(4096)
            response = json.loads(data.decode())
        except (socket.timeout, ValueError):
            print("打洞请求失败: 协调服务器无有效响应")
            return False
        if response.get('status') != 'success':
            print(f"打洞请求失败: {response.get('message')}")
            return False

        print(f"尝试打洞到 {self.provider_addr}...")
        for _ in range(10):
            self.udp_sock.sendto(b'PUNCH', self.provider_addr)  # raw punch datagram
            time.sleep(0.2)

        try:
            self.udp_sock.sendto(json.dumps({'action': 'punch_check'}).encode(), self.provider_addr)
            deadline = time.monotonic() + 5.0
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise socket.timeout
                self.udp_sock.settimeout(remaining)
                _, addr = self.udp_sock.recvfrom(4096)
                if addr == self.provider_addr:
                    print("打洞成功!")
                    return True
                # A stray datagram must not fail the punch; keep waiting.
                print(f"错误的响应来源: {addr}")
        except socket.timeout:
            print("打洞失败: 未收到响应")
            return False
        finally:
            self.udp_sock.settimeout(0.1)

    def udp_listener(self):
        """Receive datagrams on the UDP socket and dispatch them until shutdown."""
        while self.running:
            try:
                data, _ = self.udp_sock.recvfrom(65535)
            except socket.timeout:
                continue
            except Exception as e:
                if self.running:
                    print(f"UDP监听异常: {str(e)}")
                continue

            print(f"收到UDP消息: {data.decode(errors='replace')}")
            try:
                message = json.loads(data)
            except ValueError:
                # Non-JSON datagrams (e.g. raw b'PUNCH') are hole-punching traffic.
                print("JSON解析错误")
                continue
            try:
                self._handle_message(message)
            except Exception as e:
                print(f"UDP监听错误: {str(e)}")

    def _handle_message(self, message):
        """Dispatch one decoded JSON message from the UDP socket by its 'action'."""
        action = message.get('action')
        conn_id = message.get('conn_id')

        if action == 'punch_response':
            print(f"收到打洞响应: {message}")
        elif action == 'connected':
            print(f"连接建立成功: {message}")
            # Ignore duplicates: replacing a live channel would reset its sequence state.
            if conn_id in self.active_connections and conn_id not in self.reliable_channels:
                channel = ReliableChannel(self.udp_sock, self.provider_addr, conn_id)
                self.reliable_channels[conn_id] = channel
                threading.Thread(
                    target=self.monitor_channel,
                    args=(conn_id, channel),
                    daemon=True
                ).start()
        elif action == 'data':
            channel = self.reliable_channels.get(conn_id)
            if channel is not None:
                data_chunk = channel.process_packet(message['packet'])
                client_sock = self.active_connections.get(conn_id)
                if data_chunk and client_sock is not None:
                    print(f"转发数据到本地客户端: {data_chunk}")
                    try:
                        # sendall: a plain send() may write only part of the chunk.
                        client_sock.sendall(data_chunk)
                    except OSError as e:
                        # The chunk is already acknowledged and cannot be
                        # redelivered, so the stream is broken: close it.
                        print(f"转发到本地客户端失败: {str(e)}")
                        self.close_connection(conn_id)
        elif action == 'ack':
            channel = self.reliable_channels.get(conn_id)
            if channel is not None:
                channel.process_ack(message['ack_seq'])
        elif action == 'stop_conn':
            # Peer-initiated close: do not echo stop_conn back.
            self.close_connection(conn_id, notify_peer=False)
        elif action == 'connect_failed':
            print(f"连接失败: {message.get('message')}")
            self.close_connection(conn_id, notify_peer=False)
        else:
            print(f"未知动作: {action}")

    def tcp_listener(self):
        """Accept local TCP clients and ask the provider to open a matching connection."""
        while self.running:
            try:
                client_sock, client_addr = self.tcp_sock.accept()
                client_sock.settimeout(10.0)  # lets forward_data notice closure while idle
                print(f"新的本地连接来自 {client_addr}")

                conn_id = str(uuid.uuid4())
                self.active_connections[conn_id] = client_sock

                self.udp_sock.sendto(json.dumps({
                    'action': 'connect',
                    'client_id': self.client_id,
                    'conn_id': conn_id
                }).encode(), self.provider_addr)

                threading.Thread(
                    target=self.forward_data,
                    args=(conn_id, client_sock),
                    daemon=True
                ).start()
            except socket.timeout:
                pass
            except Exception as e:
                if self.running:
                    print(f"TCP监听错误: {str(e)}")

    def forward_data(self, conn_id, client_sock):
        """Relay bytes from a local TCP client into its reliable UDP channel.

        Waits up to ``connect_timeout`` seconds for the channel to be created,
        then forwards until the client disconnects or the connection is
        closed. Always closes the connection on exit.
        """
        try:
            deadline = time.monotonic() + self.connect_timeout
            while conn_id not in self.reliable_channels:
                # Closed meanwhile (e.g. connect_failed) or shutting down.
                if conn_id not in self.active_connections or not self.running:
                    return
                if time.monotonic() >= deadline:
                    print(f"连接 {conn_id} 建立超时")
                    return
                time.sleep(0.01)

            while conn_id in self.active_connections:
                try:
                    # Hex encoding doubles the size; 16 KiB reads keep each JSON
                    # datagram below the 65,507-byte UDP payload limit.
                    data = client_sock.recv(16384)
                except socket.timeout:
                    # An idle client is normal (the server may still be
                    # streaming to it); keep the tunnel open.
                    continue
                if not data:
                    break
                channel = self.reliable_channels.get(conn_id)
                if channel is not None:
                    print(f"发送数据: {data}")
                    channel.send(data)
        except Exception as e:
            if conn_id in self.active_connections:
                print(f"数据转发错误: {str(e)}")
        finally:
            self.close_connection(conn_id)

    def monitor_channel(self, conn_id, channel):
        """Drive retransmissions and periodic ACKs for ``channel`` while it is registered."""
        while self.running and self.reliable_channels.get(conn_id) is channel:
            try:
                channel.check_retransmit()
                if time.time() - channel.last_ack_time > channel.ack_interval:
                    channel.send_ack()
                    channel.last_ack_time = time.time()
            except Exception as e:
                print(f"通道监控错误: {str(e)}")
                # Without retransmission the channel can stall forever; close it.
                self.close_connection(conn_id)
                break
            time.sleep(0.1)

    def close_connection(self, conn_id, notify_peer=True):
        """Close the local socket and channel for ``conn_id``.

        Safe to call repeatedly or concurrently: only the call that actually
        removes something closes it and, when ``notify_peer`` is true, sends
        ``stop_conn`` to the provider.
        """
        sock = self.active_connections.pop(conn_id, None)
        channel = self.reliable_channels.pop(conn_id, None)
        if sock is None and channel is None:
            return

        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass
        if channel is not None:
            channel.close()

        if notify_peer and self.provider_addr:
            try:
                self.udp_sock.sendto(json.dumps({
                    'action': 'stop_conn',
                    'conn_id': conn_id
                }).encode(), self.provider_addr)
            except OSError as e:
                print(f"通知关闭连接失败: {str(e)}")

        print(f"连接 {conn_id} 已关闭")

    def run(self):
        """Set up the tunnel, serve until interrupted, and always clean up."""
        try:
            if not self.request_service():
                return
            if not self.punch_hole():
                return

            threading.Thread(target=self.udp_listener, daemon=True).start()
            threading.Thread(target=self.tcp_listener, daemon=True).start()

            while self.running:
                time.sleep(1)
        except KeyboardInterrupt:
            pass
        finally:
            self.shutdown()

    def shutdown(self):
        """Stop the connector: notify the provider, close all connections and sockets. Idempotent."""
        if getattr(self, '_shut_down', False):
            return
        self._shut_down = True
        self.running = False

        if self.provider_addr:
            try:
                self.udp_sock.sendto(json.dumps({
                    'action': 'stop_client',
                    'client_id': self.client_id
                }).encode(), self.provider_addr)
            except OSError as e:
                print(f"通知服务提供端失败: {str(e)}")

        for conn_id in set(self.active_connections) | set(self.reliable_channels):
            self.close_connection(conn_id)

        self.udp_sock.close()
        self.tcp_sock.close()
        print("服务连接端已停止")


if __name__ == '__main__':
    # Rendezvous (coordinator) server, the remote service to reach, and the
    # local TCP port that gets forwarded to that service.
    COORDINATOR_ADDR = ('www.awin-x.top', 5000)
    SERVICE_NAME = "my_service"
    LOCAL_PORT = 12345

    connector = ServiceConnector(
        coordinator_addr=COORDINATOR_ADDR,
        service_name=SERVICE_NAME,
        local_port=LOCAL_PORT,
    )
    connector.run()