提高稳定性
This commit is contained in:
parent
1147ef5a89
commit
a2364352dd
@ -11,7 +11,7 @@ from collections import deque
|
||||
class ReliableChannel:
|
||||
"""可靠传输通道类(与provider中的类似)"""
|
||||
|
||||
def __init__(self, udp_sock, remote_addr):
|
||||
def __init__(self, udp_sock, remote_addr, conn_id):
|
||||
self.udp_sock = udp_sock
|
||||
self.remote_addr = remote_addr
|
||||
self.send_buffer = deque()
|
||||
@ -20,10 +20,11 @@ class ReliableChannel:
|
||||
self.expected_seq = 0
|
||||
self.last_ack = -1
|
||||
self.ack_interval = 0.2
|
||||
self.conn_id = conn_id
|
||||
self.last_ack_time = 0
|
||||
self.window_size = 10
|
||||
self.retransmit_timers = {}
|
||||
self.retransmit_timeout = 0.5
|
||||
self.retransmit_timeout = 1
|
||||
self.running = True
|
||||
|
||||
def send(self, data, is_control=False):
|
||||
@ -50,11 +51,14 @@ class ReliableChannel:
|
||||
for packet in list(self.send_buffer)[:self.window_size]:
|
||||
self.udp_sock.sendto(json.dumps({
|
||||
'action': 'data',
|
||||
'packet': packet
|
||||
'packet': packet,
|
||||
'conn_id': self.conn_id
|
||||
}).encode(), self.remote_addr)
|
||||
|
||||
def process_ack(self, ack_seq):
|
||||
print(f"收到ack: {ack_seq}")
|
||||
if ack_seq > self.last_ack:
|
||||
print(f"更新last_ack: {ack_seq}")
|
||||
self.last_ack = ack_seq
|
||||
for seq in list(self.retransmit_timers.keys()):
|
||||
if seq <= ack_seq:
|
||||
@ -65,9 +69,9 @@ class ReliableChannel:
|
||||
seq = packet['seq']
|
||||
data = bytes.fromhex(packet['data'])
|
||||
|
||||
if time.time() - self.last_ack_time > self.ack_interval:
|
||||
self.send_ack()
|
||||
self.last_ack_time = time.time()
|
||||
# if time.time() - self.last_ack_time > self.ack_interval:
|
||||
self.send_ack()
|
||||
self.last_ack_time = time.time()
|
||||
|
||||
if seq == self.expected_seq:
|
||||
self.expected_seq += 1
|
||||
@ -82,8 +86,8 @@ class ReliableChannel:
|
||||
def _process_buffered(self):
|
||||
while self.expected_seq in self.recv_buffer:
|
||||
data = self.recv_buffer.pop(self.expected_seq)
|
||||
self.expected_seq += 1
|
||||
self.recv_queue.append(data) # 添加到队列
|
||||
# self.expected_seq += 1
|
||||
# self.recv_queue.append(data) # 添加到队列
|
||||
|
||||
def get_received_data(self):
|
||||
"""从接收队列中取出数据"""
|
||||
@ -93,7 +97,8 @@ class ReliableChannel:
|
||||
ack_packet = {
|
||||
'action': 'ack',
|
||||
'ack_seq': self.expected_seq - 1,
|
||||
'window': self.window_size
|
||||
'window': self.window_size,
|
||||
'conn_id': self.conn_id
|
||||
}
|
||||
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
||||
|
||||
@ -101,11 +106,13 @@ class ReliableChannel:
|
||||
now = time.time()
|
||||
for seq, send_time in list(self.retransmit_timers.items()):
|
||||
if now - send_time > self.retransmit_timeout:
|
||||
print(f'有重传{seq}')
|
||||
for packet in self.send_buffer:
|
||||
if packet['seq'] == seq:
|
||||
self.udp_sock.sendto(json.dumps({
|
||||
'action': 'data',
|
||||
'packet': packet
|
||||
'packet': packet,
|
||||
'conn_id': self.conn_id
|
||||
}).encode(), self.remote_addr)
|
||||
self.retransmit_timers[seq] = now
|
||||
break
|
||||
@ -128,7 +135,7 @@ class ServiceConnector:
|
||||
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.udp_sock.bind(('0.0.0.0', 0))
|
||||
self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) # 1MB接收缓冲区
|
||||
self.udp_sock.settimeout(0.1)
|
||||
self.udp_sock.settimeout(5)
|
||||
|
||||
# 创建TCP监听套接字
|
||||
self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
@ -195,7 +202,7 @@ class ServiceConnector:
|
||||
self.udp_sock.sendto(json.dumps({'action': 'punch_check'}).encode(), self.provider_addr)
|
||||
data, addr = self.udp_sock.recvfrom(4096)
|
||||
if addr == self.provider_addr:
|
||||
print("打洞成功! 已建立UDP连接")
|
||||
print("打洞成功!")
|
||||
return True
|
||||
else:
|
||||
print(f"错误的响应来源: {addr}")
|
||||
@ -211,19 +218,22 @@ class ServiceConnector:
|
||||
while self.running:
|
||||
try:
|
||||
data, addr = self.udp_sock.recvfrom(65535)
|
||||
print(f"收到UDP消息: {data.decode()}")
|
||||
try:
|
||||
message = json.loads(data.decode())
|
||||
action = message.get('action')
|
||||
|
||||
if action == 'punch_response':
|
||||
# 打洞响应
|
||||
print(f"收到打洞响应: {message}")
|
||||
pass
|
||||
elif action == 'connected':
|
||||
# 连接建立成功
|
||||
print(f"连接建立成功: {message}")
|
||||
conn_id = message['conn_id']
|
||||
if conn_id in self.active_connections:
|
||||
# 创建可靠通道
|
||||
channel = ReliableChannel(self.udp_sock, self.provider_addr)
|
||||
channel = ReliableChannel(self.udp_sock, self.provider_addr, conn_id)
|
||||
self.reliable_channels[conn_id] = channel
|
||||
|
||||
# 启动通道监控
|
||||
@ -233,7 +243,6 @@ class ServiceConnector:
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
|
||||
elif action == 'data':
|
||||
# 处理数据包
|
||||
conn_id = message.get('conn_id')
|
||||
@ -245,6 +254,7 @@ class ServiceConnector:
|
||||
data_chunk = channel.process_packet(packet)
|
||||
if data_chunk and conn_id in self.active_connections:
|
||||
# 转发数据到本地客户端
|
||||
print(f"转发数据到本地客户端: {data_chunk}")
|
||||
self.active_connections[conn_id].send(data_chunk)
|
||||
elif action == 'ack':
|
||||
# 处理ACK确认
|
||||
@ -260,7 +270,10 @@ class ServiceConnector:
|
||||
conn_id = message['conn_id']
|
||||
print(f"连接失败: {message['message']}")
|
||||
self.close_connection(conn_id)
|
||||
else:
|
||||
print(f"未知动作: {action}")
|
||||
except json.JSONDecodeError:
|
||||
print("JSON解析错误")
|
||||
# 原始UDP包可能是打洞确认
|
||||
pass
|
||||
except Exception as e:
|
||||
@ -296,30 +309,16 @@ class ServiceConnector:
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
threading.Thread(
|
||||
target=self.forward_incoming_data,
|
||||
args=(conn_id, client_sock),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
except socket.timeout:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"TCP监听错误: {str(e)}")
|
||||
|
||||
def forward_incoming_data(self, conn_id, local_sock):
|
||||
while conn_id in self.active_connections:
|
||||
if conn_id in self.reliable_channels:
|
||||
channel = self.reliable_channels[conn_id]
|
||||
data = channel.get_received_data()
|
||||
if data:
|
||||
local_sock.sendall(data)
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
def forward_data(self, conn_id, client_sock):
|
||||
"""转发本地TCP数据到UDP通道"""
|
||||
try:
|
||||
while conn_id not in self.reliable_channels:
|
||||
time.sleep(0.01)
|
||||
while conn_id in self.active_connections:
|
||||
# 从本地客户端读取数据
|
||||
data = client_sock.recv(32768) # 32KB数据块
|
||||
@ -329,6 +328,7 @@ class ServiceConnector:
|
||||
# 通过可靠通道发送
|
||||
if conn_id in self.reliable_channels:
|
||||
channel = self.reliable_channels[conn_id]
|
||||
print(f"发送数据: {data}")
|
||||
channel.send(data)
|
||||
except socket.timeout:
|
||||
pass
|
||||
@ -407,4 +407,4 @@ if __name__ == '__main__':
|
||||
LOCAL_PORT = 12345
|
||||
|
||||
connector = ServiceConnector(COORDINATOR_ADDR, SERVICE_NAME, LOCAL_PORT)
|
||||
connector.run()
|
||||
connector.run()
|
||||
|
||||
38
provider1.py
38
provider1.py
@ -9,7 +9,7 @@ from collections import deque
|
||||
class ReliableChannel:
|
||||
"""可靠传输通道类,实现序列号、ACK确认和重传"""
|
||||
|
||||
def __init__(self, udp_sock, remote_addr):
|
||||
def __init__(self, udp_sock, remote_addr, conn_id):
|
||||
self.udp_sock = udp_sock
|
||||
self.remote_addr = remote_addr
|
||||
self.send_buffer = deque() # 发送缓冲区
|
||||
@ -22,10 +22,12 @@ class ReliableChannel:
|
||||
self.window_size = 10 # 滑动窗口大小
|
||||
self.retransmit_timers = {} # 重传计时器 {seq: send_time}
|
||||
self.retransmit_timeout = 0.5 # 重传超时时间 (秒)
|
||||
self.conn_id = conn_id
|
||||
self.running = True
|
||||
|
||||
def send(self, data, is_control=False):
|
||||
"""发送数据并管理发送缓冲区"""
|
||||
print(f"发送数据: {data}")
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
@ -56,11 +58,13 @@ class ReliableChannel:
|
||||
for packet in list(self.send_buffer)[:self.window_size]:
|
||||
self.udp_sock.sendto(json.dumps({
|
||||
'action': 'data',
|
||||
'packet': packet
|
||||
'packet': packet,
|
||||
'conn_id': self.conn_id
|
||||
}).encode(), self.remote_addr)
|
||||
|
||||
def process_ack(self, ack_seq):
|
||||
"""处理接收到的ACK"""
|
||||
print(f"收到ACK: {ack_seq}")
|
||||
if ack_seq > self.last_ack:
|
||||
self.last_ack = ack_seq
|
||||
# 清除重传计时器
|
||||
@ -98,8 +102,8 @@ class ReliableChannel:
|
||||
"""处理接收缓冲区中的连续数据"""
|
||||
while self.expected_seq in self.recv_buffer:
|
||||
data = self.recv_buffer.pop(self.expected_seq)
|
||||
self.expected_seq += 1
|
||||
self.recv_queue.append(data) # 添加到队列
|
||||
# self.expected_seq += 1
|
||||
# self.recv_queue.append(data) # 添加到队列
|
||||
|
||||
def get_received_data(self):
|
||||
"""从接收队列中取出数据"""
|
||||
@ -110,7 +114,8 @@ class ReliableChannel:
|
||||
ack_packet = {
|
||||
'action': 'ack',
|
||||
'ack_seq': self.expected_seq - 1, # 确认到目前收到的最大连续序列号
|
||||
'window': self.window_size
|
||||
'window': self.window_size,
|
||||
'conn_id': self.conn_id
|
||||
}
|
||||
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
||||
|
||||
@ -124,7 +129,8 @@ class ReliableChannel:
|
||||
if packet['seq'] == seq:
|
||||
self.udp_sock.sendto(json.dumps({
|
||||
'action': 'data',
|
||||
'packet': packet
|
||||
'packet': packet,
|
||||
'conn_id': self.conn_id
|
||||
}).encode(), self.remote_addr)
|
||||
self.retransmit_timers[seq] = now
|
||||
break
|
||||
@ -181,6 +187,7 @@ class ServiceProvider:
|
||||
while self.running:
|
||||
try:
|
||||
data, addr = self.udp_sock.recvfrom(65535)
|
||||
print(f"收到UDP消息: {data.decode()}")
|
||||
try:
|
||||
message = json.loads(data.decode())
|
||||
action = message.get('action')
|
||||
@ -216,6 +223,7 @@ class ServiceProvider:
|
||||
data_chunk = channel.process_packet(packet)
|
||||
if data_chunk and conn_id in self.active_connections:
|
||||
# 转发数据到本地服务
|
||||
print(f"转发数据包到本地服务: {data_chunk[:10]}")
|
||||
self.active_connections[conn_id]['local_sock'].sendall(data_chunk)
|
||||
elif action == 'ack':
|
||||
# 处理ACK确认
|
||||
@ -249,7 +257,7 @@ class ServiceProvider:
|
||||
local_sock.connect(('127.0.0.1', self.internal_port))
|
||||
|
||||
# 创建可靠通道
|
||||
channel = ReliableChannel(self.udp_sock, client_addr)
|
||||
channel = ReliableChannel(self.udp_sock, client_addr, conn_id)
|
||||
self.reliable_channels[conn_id] = channel
|
||||
|
||||
# 存储连接
|
||||
@ -273,12 +281,6 @@ class ServiceProvider:
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
threading.Thread(
|
||||
target=self.forward_incoming_data,
|
||||
args=(conn_id, channel, local_sock),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
threading.Thread(
|
||||
target=self.monitor_channel,
|
||||
args=(conn_id, channel),
|
||||
@ -294,14 +296,6 @@ class ServiceProvider:
|
||||
'message': str(e)
|
||||
}).encode(), client_addr)
|
||||
|
||||
def forward_incoming_data(self, conn_id, channel, local_sock):
|
||||
while conn_id in self.active_connections:
|
||||
data = channel.get_received_data()
|
||||
if data:
|
||||
local_sock.sendall(data)
|
||||
else:
|
||||
time.sleep(0.02)
|
||||
|
||||
def forward_data(self, conn_id, local_sock, channel):
|
||||
"""转发TCP数据到UDP通道"""
|
||||
try:
|
||||
@ -387,7 +381,7 @@ class ServiceProvider:
|
||||
if __name__ == '__main__':
|
||||
COORDINATOR_ADDR = ('www.awin-x.top', 5000)
|
||||
SERVICE_NAME = "my_service"
|
||||
INTERNAL_PORT = 5001
|
||||
INTERNAL_PORT = 22
|
||||
|
||||
provider = ServiceProvider(COORDINATOR_ADDR, SERVICE_NAME, INTERNAL_PORT)
|
||||
provider.run()
|
||||
Loading…
Reference in New Issue
Block a user