提高稳定性

This commit is contained in:
awin-x 2025-05-31 17:52:01 +08:00
parent 1147ef5a89
commit a2364352dd
2 changed files with 47 additions and 53 deletions

View File

@ -11,7 +11,7 @@ from collections import deque
class ReliableChannel:
"""可靠传输通道类与provider中的类似"""
def __init__(self, udp_sock, remote_addr):
def __init__(self, udp_sock, remote_addr, conn_id):
self.udp_sock = udp_sock
self.remote_addr = remote_addr
self.send_buffer = deque()
@ -20,10 +20,11 @@ class ReliableChannel:
self.expected_seq = 0
self.last_ack = -1
self.ack_interval = 0.2
self.conn_id = conn_id
self.last_ack_time = 0
self.window_size = 10
self.retransmit_timers = {}
self.retransmit_timeout = 0.5
self.retransmit_timeout = 1
self.running = True
def send(self, data, is_control=False):
@ -50,11 +51,14 @@ class ReliableChannel:
for packet in list(self.send_buffer)[:self.window_size]:
self.udp_sock.sendto(json.dumps({
'action': 'data',
'packet': packet
'packet': packet,
'conn_id': self.conn_id
}).encode(), self.remote_addr)
def process_ack(self, ack_seq):
print(f"收到ack: {ack_seq}")
if ack_seq > self.last_ack:
print(f"更新last_ack: {ack_seq}")
self.last_ack = ack_seq
for seq in list(self.retransmit_timers.keys()):
if seq <= ack_seq:
@ -65,9 +69,9 @@ class ReliableChannel:
seq = packet['seq']
data = bytes.fromhex(packet['data'])
if time.time() - self.last_ack_time > self.ack_interval:
self.send_ack()
self.last_ack_time = time.time()
# if time.time() - self.last_ack_time > self.ack_interval:
self.send_ack()
self.last_ack_time = time.time()
if seq == self.expected_seq:
self.expected_seq += 1
@ -82,8 +86,8 @@ class ReliableChannel:
def _process_buffered(self):
while self.expected_seq in self.recv_buffer:
data = self.recv_buffer.pop(self.expected_seq)
self.expected_seq += 1
self.recv_queue.append(data) # 添加到队列
# self.expected_seq += 1
# self.recv_queue.append(data) # 添加到队列
def get_received_data(self):
"""从接收队列中取出数据"""
@ -93,7 +97,8 @@ class ReliableChannel:
ack_packet = {
'action': 'ack',
'ack_seq': self.expected_seq - 1,
'window': self.window_size
'window': self.window_size,
'conn_id': self.conn_id
}
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
@ -101,11 +106,13 @@ class ReliableChannel:
now = time.time()
for seq, send_time in list(self.retransmit_timers.items()):
if now - send_time > self.retransmit_timeout:
print(f'有重传{seq}')
for packet in self.send_buffer:
if packet['seq'] == seq:
self.udp_sock.sendto(json.dumps({
'action': 'data',
'packet': packet
'packet': packet,
'conn_id': self.conn_id
}).encode(), self.remote_addr)
self.retransmit_timers[seq] = now
break
@ -128,7 +135,7 @@ class ServiceConnector:
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.udp_sock.bind(('0.0.0.0', 0))
self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) # 1MB接收缓冲区
self.udp_sock.settimeout(0.1)
self.udp_sock.settimeout(5)
# 创建TCP监听套接字
self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -195,7 +202,7 @@ class ServiceConnector:
self.udp_sock.sendto(json.dumps({'action': 'punch_check'}).encode(), self.provider_addr)
data, addr = self.udp_sock.recvfrom(4096)
if addr == self.provider_addr:
print("打洞成功! 已建立UDP连接")
print("打洞成功!")
return True
else:
print(f"错误的响应来源: {addr}")
@ -211,19 +218,22 @@ class ServiceConnector:
while self.running:
try:
data, addr = self.udp_sock.recvfrom(65535)
print(f"收到UDP消息: {data.decode()}")
try:
message = json.loads(data.decode())
action = message.get('action')
if action == 'punch_response':
# 打洞响应
print(f"收到打洞响应: {message}")
pass
elif action == 'connected':
# 连接建立成功
print(f"连接建立成功: {message}")
conn_id = message['conn_id']
if conn_id in self.active_connections:
# 创建可靠通道
channel = ReliableChannel(self.udp_sock, self.provider_addr)
channel = ReliableChannel(self.udp_sock, self.provider_addr, conn_id)
self.reliable_channels[conn_id] = channel
# 启动通道监控
@ -233,7 +243,6 @@ class ServiceConnector:
daemon=True
).start()
elif action == 'data':
# 处理数据包
conn_id = message.get('conn_id')
@ -245,6 +254,7 @@ class ServiceConnector:
data_chunk = channel.process_packet(packet)
if data_chunk and conn_id in self.active_connections:
# 转发数据到本地客户端
print(f"转发数据到本地客户端: {data_chunk}")
self.active_connections[conn_id].send(data_chunk)
elif action == 'ack':
# 处理ACK确认
@ -260,7 +270,10 @@ class ServiceConnector:
conn_id = message['conn_id']
print(f"连接失败: {message['message']}")
self.close_connection(conn_id)
else:
print(f"未知动作: {action}")
except json.JSONDecodeError:
print("JSON解析错误")
# 原始UDP包可能是打洞确认
pass
except Exception as e:
@ -296,30 +309,16 @@ class ServiceConnector:
daemon=True
).start()
threading.Thread(
target=self.forward_incoming_data,
args=(conn_id, client_sock),
daemon=True
).start()
except socket.timeout:
pass
except Exception as e:
print(f"TCP监听错误: {str(e)}")
def forward_incoming_data(self, conn_id, local_sock):
while conn_id in self.active_connections:
if conn_id in self.reliable_channels:
channel = self.reliable_channels[conn_id]
data = channel.get_received_data()
if data:
local_sock.sendall(data)
else:
time.sleep(0.1)
def forward_data(self, conn_id, client_sock):
"""转发本地TCP数据到UDP通道"""
try:
while conn_id not in self.reliable_channels:
time.sleep(0.01)
while conn_id in self.active_connections:
# 从本地客户端读取数据
data = client_sock.recv(32768) # 32KB数据块
@ -329,6 +328,7 @@ class ServiceConnector:
# 通过可靠通道发送
if conn_id in self.reliable_channels:
channel = self.reliable_channels[conn_id]
print(f"发送数据: {data}")
channel.send(data)
except socket.timeout:
pass

View File

@ -9,7 +9,7 @@ from collections import deque
class ReliableChannel:
"""可靠传输通道类实现序列号、ACK确认和重传"""
def __init__(self, udp_sock, remote_addr):
def __init__(self, udp_sock, remote_addr, conn_id):
self.udp_sock = udp_sock
self.remote_addr = remote_addr
self.send_buffer = deque() # 发送缓冲区
@ -22,10 +22,12 @@ class ReliableChannel:
self.window_size = 10 # 滑动窗口大小
self.retransmit_timers = {} # 重传计时器 {seq: send_time}
self.retransmit_timeout = 0.5 # 重传超时时间 (秒)
self.conn_id = conn_id
self.running = True
def send(self, data, is_control=False):
"""发送数据并管理发送缓冲区"""
print(f"发送数据: {data}")
if not self.running:
return
@ -56,11 +58,13 @@ class ReliableChannel:
for packet in list(self.send_buffer)[:self.window_size]:
self.udp_sock.sendto(json.dumps({
'action': 'data',
'packet': packet
'packet': packet,
'conn_id': self.conn_id
}).encode(), self.remote_addr)
def process_ack(self, ack_seq):
"""处理接收到的ACK"""
print(f"收到ACK: {ack_seq}")
if ack_seq > self.last_ack:
self.last_ack = ack_seq
# 清除重传计时器
@ -98,8 +102,8 @@ class ReliableChannel:
"""处理接收缓冲区中的连续数据"""
while self.expected_seq in self.recv_buffer:
data = self.recv_buffer.pop(self.expected_seq)
self.expected_seq += 1
self.recv_queue.append(data) # 添加到队列
# self.expected_seq += 1
# self.recv_queue.append(data) # 添加到队列
def get_received_data(self):
"""从接收队列中取出数据"""
@ -110,7 +114,8 @@ class ReliableChannel:
ack_packet = {
'action': 'ack',
'ack_seq': self.expected_seq - 1, # 确认到目前收到的最大连续序列号
'window': self.window_size
'window': self.window_size,
'conn_id': self.conn_id
}
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
@ -124,7 +129,8 @@ class ReliableChannel:
if packet['seq'] == seq:
self.udp_sock.sendto(json.dumps({
'action': 'data',
'packet': packet
'packet': packet,
'conn_id': self.conn_id
}).encode(), self.remote_addr)
self.retransmit_timers[seq] = now
break
@ -181,6 +187,7 @@ class ServiceProvider:
while self.running:
try:
data, addr = self.udp_sock.recvfrom(65535)
print(f"收到UDP消息: {data.decode()}")
try:
message = json.loads(data.decode())
action = message.get('action')
@ -216,6 +223,7 @@ class ServiceProvider:
data_chunk = channel.process_packet(packet)
if data_chunk and conn_id in self.active_connections:
# 转发数据到本地服务
print(f"转发数据包到本地服务: {data_chunk[:10]}")
self.active_connections[conn_id]['local_sock'].sendall(data_chunk)
elif action == 'ack':
# 处理ACK确认
@ -249,7 +257,7 @@ class ServiceProvider:
local_sock.connect(('127.0.0.1', self.internal_port))
# 创建可靠通道
channel = ReliableChannel(self.udp_sock, client_addr)
channel = ReliableChannel(self.udp_sock, client_addr, conn_id)
self.reliable_channels[conn_id] = channel
# 存储连接
@ -273,12 +281,6 @@ class ServiceProvider:
daemon=True
).start()
threading.Thread(
target=self.forward_incoming_data,
args=(conn_id, channel, local_sock),
daemon=True
).start()
threading.Thread(
target=self.monitor_channel,
args=(conn_id, channel),
@ -294,14 +296,6 @@ class ServiceProvider:
'message': str(e)
}).encode(), client_addr)
def forward_incoming_data(self, conn_id, channel, local_sock):
while conn_id in self.active_connections:
data = channel.get_received_data()
if data:
local_sock.sendall(data)
else:
time.sleep(0.02)
def forward_data(self, conn_id, local_sock, channel):
"""转发TCP数据到UDP通道"""
try:
@ -387,7 +381,7 @@ class ServiceProvider:
if __name__ == '__main__':
COORDINATOR_ADDR = ('www.awin-x.top', 5000)
SERVICE_NAME = "my_service"
INTERNAL_PORT = 5001
INTERNAL_PORT = 22
provider = ServiceProvider(COORDINATOR_ADDR, SERVICE_NAME, INTERNAL_PORT)
provider.run()