提高稳定性
This commit is contained in:
parent
1147ef5a89
commit
a2364352dd
@ -11,7 +11,7 @@ from collections import deque
|
|||||||
class ReliableChannel:
|
class ReliableChannel:
|
||||||
"""可靠传输通道类(与provider中的类似)"""
|
"""可靠传输通道类(与provider中的类似)"""
|
||||||
|
|
||||||
def __init__(self, udp_sock, remote_addr):
|
def __init__(self, udp_sock, remote_addr, conn_id):
|
||||||
self.udp_sock = udp_sock
|
self.udp_sock = udp_sock
|
||||||
self.remote_addr = remote_addr
|
self.remote_addr = remote_addr
|
||||||
self.send_buffer = deque()
|
self.send_buffer = deque()
|
||||||
@ -20,10 +20,11 @@ class ReliableChannel:
|
|||||||
self.expected_seq = 0
|
self.expected_seq = 0
|
||||||
self.last_ack = -1
|
self.last_ack = -1
|
||||||
self.ack_interval = 0.2
|
self.ack_interval = 0.2
|
||||||
|
self.conn_id = conn_id
|
||||||
self.last_ack_time = 0
|
self.last_ack_time = 0
|
||||||
self.window_size = 10
|
self.window_size = 10
|
||||||
self.retransmit_timers = {}
|
self.retransmit_timers = {}
|
||||||
self.retransmit_timeout = 0.5
|
self.retransmit_timeout = 1
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
def send(self, data, is_control=False):
|
def send(self, data, is_control=False):
|
||||||
@ -50,11 +51,14 @@ class ReliableChannel:
|
|||||||
for packet in list(self.send_buffer)[:self.window_size]:
|
for packet in list(self.send_buffer)[:self.window_size]:
|
||||||
self.udp_sock.sendto(json.dumps({
|
self.udp_sock.sendto(json.dumps({
|
||||||
'action': 'data',
|
'action': 'data',
|
||||||
'packet': packet
|
'packet': packet,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}).encode(), self.remote_addr)
|
}).encode(), self.remote_addr)
|
||||||
|
|
||||||
def process_ack(self, ack_seq):
|
def process_ack(self, ack_seq):
|
||||||
|
print(f"收到ack: {ack_seq}")
|
||||||
if ack_seq > self.last_ack:
|
if ack_seq > self.last_ack:
|
||||||
|
print(f"更新last_ack: {ack_seq}")
|
||||||
self.last_ack = ack_seq
|
self.last_ack = ack_seq
|
||||||
for seq in list(self.retransmit_timers.keys()):
|
for seq in list(self.retransmit_timers.keys()):
|
||||||
if seq <= ack_seq:
|
if seq <= ack_seq:
|
||||||
@ -65,9 +69,9 @@ class ReliableChannel:
|
|||||||
seq = packet['seq']
|
seq = packet['seq']
|
||||||
data = bytes.fromhex(packet['data'])
|
data = bytes.fromhex(packet['data'])
|
||||||
|
|
||||||
if time.time() - self.last_ack_time > self.ack_interval:
|
# if time.time() - self.last_ack_time > self.ack_interval:
|
||||||
self.send_ack()
|
self.send_ack()
|
||||||
self.last_ack_time = time.time()
|
self.last_ack_time = time.time()
|
||||||
|
|
||||||
if seq == self.expected_seq:
|
if seq == self.expected_seq:
|
||||||
self.expected_seq += 1
|
self.expected_seq += 1
|
||||||
@ -82,8 +86,8 @@ class ReliableChannel:
|
|||||||
def _process_buffered(self):
|
def _process_buffered(self):
|
||||||
while self.expected_seq in self.recv_buffer:
|
while self.expected_seq in self.recv_buffer:
|
||||||
data = self.recv_buffer.pop(self.expected_seq)
|
data = self.recv_buffer.pop(self.expected_seq)
|
||||||
self.expected_seq += 1
|
# self.expected_seq += 1
|
||||||
self.recv_queue.append(data) # 添加到队列
|
# self.recv_queue.append(data) # 添加到队列
|
||||||
|
|
||||||
def get_received_data(self):
|
def get_received_data(self):
|
||||||
"""从接收队列中取出数据"""
|
"""从接收队列中取出数据"""
|
||||||
@ -93,7 +97,8 @@ class ReliableChannel:
|
|||||||
ack_packet = {
|
ack_packet = {
|
||||||
'action': 'ack',
|
'action': 'ack',
|
||||||
'ack_seq': self.expected_seq - 1,
|
'ack_seq': self.expected_seq - 1,
|
||||||
'window': self.window_size
|
'window': self.window_size,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}
|
}
|
||||||
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
||||||
|
|
||||||
@ -101,11 +106,13 @@ class ReliableChannel:
|
|||||||
now = time.time()
|
now = time.time()
|
||||||
for seq, send_time in list(self.retransmit_timers.items()):
|
for seq, send_time in list(self.retransmit_timers.items()):
|
||||||
if now - send_time > self.retransmit_timeout:
|
if now - send_time > self.retransmit_timeout:
|
||||||
|
print(f'有重传{seq}')
|
||||||
for packet in self.send_buffer:
|
for packet in self.send_buffer:
|
||||||
if packet['seq'] == seq:
|
if packet['seq'] == seq:
|
||||||
self.udp_sock.sendto(json.dumps({
|
self.udp_sock.sendto(json.dumps({
|
||||||
'action': 'data',
|
'action': 'data',
|
||||||
'packet': packet
|
'packet': packet,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}).encode(), self.remote_addr)
|
}).encode(), self.remote_addr)
|
||||||
self.retransmit_timers[seq] = now
|
self.retransmit_timers[seq] = now
|
||||||
break
|
break
|
||||||
@ -128,7 +135,7 @@ class ServiceConnector:
|
|||||||
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
self.udp_sock.bind(('0.0.0.0', 0))
|
self.udp_sock.bind(('0.0.0.0', 0))
|
||||||
self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) # 1MB接收缓冲区
|
self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) # 1MB接收缓冲区
|
||||||
self.udp_sock.settimeout(0.1)
|
self.udp_sock.settimeout(5)
|
||||||
|
|
||||||
# 创建TCP监听套接字
|
# 创建TCP监听套接字
|
||||||
self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
@ -195,7 +202,7 @@ class ServiceConnector:
|
|||||||
self.udp_sock.sendto(json.dumps({'action': 'punch_check'}).encode(), self.provider_addr)
|
self.udp_sock.sendto(json.dumps({'action': 'punch_check'}).encode(), self.provider_addr)
|
||||||
data, addr = self.udp_sock.recvfrom(4096)
|
data, addr = self.udp_sock.recvfrom(4096)
|
||||||
if addr == self.provider_addr:
|
if addr == self.provider_addr:
|
||||||
print("打洞成功! 已建立UDP连接")
|
print("打洞成功!")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
print(f"错误的响应来源: {addr}")
|
print(f"错误的响应来源: {addr}")
|
||||||
@ -211,19 +218,22 @@ class ServiceConnector:
|
|||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
data, addr = self.udp_sock.recvfrom(65535)
|
data, addr = self.udp_sock.recvfrom(65535)
|
||||||
|
print(f"收到UDP消息: {data.decode()}")
|
||||||
try:
|
try:
|
||||||
message = json.loads(data.decode())
|
message = json.loads(data.decode())
|
||||||
action = message.get('action')
|
action = message.get('action')
|
||||||
|
|
||||||
if action == 'punch_response':
|
if action == 'punch_response':
|
||||||
# 打洞响应
|
# 打洞响应
|
||||||
|
print(f"收到打洞响应: {message}")
|
||||||
pass
|
pass
|
||||||
elif action == 'connected':
|
elif action == 'connected':
|
||||||
# 连接建立成功
|
# 连接建立成功
|
||||||
|
print(f"连接建立成功: {message}")
|
||||||
conn_id = message['conn_id']
|
conn_id = message['conn_id']
|
||||||
if conn_id in self.active_connections:
|
if conn_id in self.active_connections:
|
||||||
# 创建可靠通道
|
# 创建可靠通道
|
||||||
channel = ReliableChannel(self.udp_sock, self.provider_addr)
|
channel = ReliableChannel(self.udp_sock, self.provider_addr, conn_id)
|
||||||
self.reliable_channels[conn_id] = channel
|
self.reliable_channels[conn_id] = channel
|
||||||
|
|
||||||
# 启动通道监控
|
# 启动通道监控
|
||||||
@ -233,7 +243,6 @@ class ServiceConnector:
|
|||||||
daemon=True
|
daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
|
||||||
elif action == 'data':
|
elif action == 'data':
|
||||||
# 处理数据包
|
# 处理数据包
|
||||||
conn_id = message.get('conn_id')
|
conn_id = message.get('conn_id')
|
||||||
@ -245,6 +254,7 @@ class ServiceConnector:
|
|||||||
data_chunk = channel.process_packet(packet)
|
data_chunk = channel.process_packet(packet)
|
||||||
if data_chunk and conn_id in self.active_connections:
|
if data_chunk and conn_id in self.active_connections:
|
||||||
# 转发数据到本地客户端
|
# 转发数据到本地客户端
|
||||||
|
print(f"转发数据到本地客户端: {data_chunk}")
|
||||||
self.active_connections[conn_id].send(data_chunk)
|
self.active_connections[conn_id].send(data_chunk)
|
||||||
elif action == 'ack':
|
elif action == 'ack':
|
||||||
# 处理ACK确认
|
# 处理ACK确认
|
||||||
@ -260,7 +270,10 @@ class ServiceConnector:
|
|||||||
conn_id = message['conn_id']
|
conn_id = message['conn_id']
|
||||||
print(f"连接失败: {message['message']}")
|
print(f"连接失败: {message['message']}")
|
||||||
self.close_connection(conn_id)
|
self.close_connection(conn_id)
|
||||||
|
else:
|
||||||
|
print(f"未知动作: {action}")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
print("JSON解析错误")
|
||||||
# 原始UDP包可能是打洞确认
|
# 原始UDP包可能是打洞确认
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -296,30 +309,16 @@ class ServiceConnector:
|
|||||||
daemon=True
|
daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
threading.Thread(
|
|
||||||
target=self.forward_incoming_data,
|
|
||||||
args=(conn_id, client_sock),
|
|
||||||
daemon=True
|
|
||||||
).start()
|
|
||||||
|
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"TCP监听错误: {str(e)}")
|
print(f"TCP监听错误: {str(e)}")
|
||||||
|
|
||||||
def forward_incoming_data(self, conn_id, local_sock):
|
|
||||||
while conn_id in self.active_connections:
|
|
||||||
if conn_id in self.reliable_channels:
|
|
||||||
channel = self.reliable_channels[conn_id]
|
|
||||||
data = channel.get_received_data()
|
|
||||||
if data:
|
|
||||||
local_sock.sendall(data)
|
|
||||||
else:
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
def forward_data(self, conn_id, client_sock):
|
def forward_data(self, conn_id, client_sock):
|
||||||
"""转发本地TCP数据到UDP通道"""
|
"""转发本地TCP数据到UDP通道"""
|
||||||
try:
|
try:
|
||||||
|
while conn_id not in self.reliable_channels:
|
||||||
|
time.sleep(0.01)
|
||||||
while conn_id in self.active_connections:
|
while conn_id in self.active_connections:
|
||||||
# 从本地客户端读取数据
|
# 从本地客户端读取数据
|
||||||
data = client_sock.recv(32768) # 32KB数据块
|
data = client_sock.recv(32768) # 32KB数据块
|
||||||
@ -329,6 +328,7 @@ class ServiceConnector:
|
|||||||
# 通过可靠通道发送
|
# 通过可靠通道发送
|
||||||
if conn_id in self.reliable_channels:
|
if conn_id in self.reliable_channels:
|
||||||
channel = self.reliable_channels[conn_id]
|
channel = self.reliable_channels[conn_id]
|
||||||
|
print(f"发送数据: {data}")
|
||||||
channel.send(data)
|
channel.send(data)
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
pass
|
pass
|
||||||
|
|||||||
38
provider1.py
38
provider1.py
@ -9,7 +9,7 @@ from collections import deque
|
|||||||
class ReliableChannel:
|
class ReliableChannel:
|
||||||
"""可靠传输通道类,实现序列号、ACK确认和重传"""
|
"""可靠传输通道类,实现序列号、ACK确认和重传"""
|
||||||
|
|
||||||
def __init__(self, udp_sock, remote_addr):
|
def __init__(self, udp_sock, remote_addr, conn_id):
|
||||||
self.udp_sock = udp_sock
|
self.udp_sock = udp_sock
|
||||||
self.remote_addr = remote_addr
|
self.remote_addr = remote_addr
|
||||||
self.send_buffer = deque() # 发送缓冲区
|
self.send_buffer = deque() # 发送缓冲区
|
||||||
@ -22,10 +22,12 @@ class ReliableChannel:
|
|||||||
self.window_size = 10 # 滑动窗口大小
|
self.window_size = 10 # 滑动窗口大小
|
||||||
self.retransmit_timers = {} # 重传计时器 {seq: send_time}
|
self.retransmit_timers = {} # 重传计时器 {seq: send_time}
|
||||||
self.retransmit_timeout = 0.5 # 重传超时时间 (秒)
|
self.retransmit_timeout = 0.5 # 重传超时时间 (秒)
|
||||||
|
self.conn_id = conn_id
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
def send(self, data, is_control=False):
|
def send(self, data, is_control=False):
|
||||||
"""发送数据并管理发送缓冲区"""
|
"""发送数据并管理发送缓冲区"""
|
||||||
|
print(f"发送数据: {data}")
|
||||||
if not self.running:
|
if not self.running:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -56,11 +58,13 @@ class ReliableChannel:
|
|||||||
for packet in list(self.send_buffer)[:self.window_size]:
|
for packet in list(self.send_buffer)[:self.window_size]:
|
||||||
self.udp_sock.sendto(json.dumps({
|
self.udp_sock.sendto(json.dumps({
|
||||||
'action': 'data',
|
'action': 'data',
|
||||||
'packet': packet
|
'packet': packet,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}).encode(), self.remote_addr)
|
}).encode(), self.remote_addr)
|
||||||
|
|
||||||
def process_ack(self, ack_seq):
|
def process_ack(self, ack_seq):
|
||||||
"""处理接收到的ACK"""
|
"""处理接收到的ACK"""
|
||||||
|
print(f"收到ACK: {ack_seq}")
|
||||||
if ack_seq > self.last_ack:
|
if ack_seq > self.last_ack:
|
||||||
self.last_ack = ack_seq
|
self.last_ack = ack_seq
|
||||||
# 清除重传计时器
|
# 清除重传计时器
|
||||||
@ -98,8 +102,8 @@ class ReliableChannel:
|
|||||||
"""处理接收缓冲区中的连续数据"""
|
"""处理接收缓冲区中的连续数据"""
|
||||||
while self.expected_seq in self.recv_buffer:
|
while self.expected_seq in self.recv_buffer:
|
||||||
data = self.recv_buffer.pop(self.expected_seq)
|
data = self.recv_buffer.pop(self.expected_seq)
|
||||||
self.expected_seq += 1
|
# self.expected_seq += 1
|
||||||
self.recv_queue.append(data) # 添加到队列
|
# self.recv_queue.append(data) # 添加到队列
|
||||||
|
|
||||||
def get_received_data(self):
|
def get_received_data(self):
|
||||||
"""从接收队列中取出数据"""
|
"""从接收队列中取出数据"""
|
||||||
@ -110,7 +114,8 @@ class ReliableChannel:
|
|||||||
ack_packet = {
|
ack_packet = {
|
||||||
'action': 'ack',
|
'action': 'ack',
|
||||||
'ack_seq': self.expected_seq - 1, # 确认到目前收到的最大连续序列号
|
'ack_seq': self.expected_seq - 1, # 确认到目前收到的最大连续序列号
|
||||||
'window': self.window_size
|
'window': self.window_size,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}
|
}
|
||||||
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)
|
||||||
|
|
||||||
@ -124,7 +129,8 @@ class ReliableChannel:
|
|||||||
if packet['seq'] == seq:
|
if packet['seq'] == seq:
|
||||||
self.udp_sock.sendto(json.dumps({
|
self.udp_sock.sendto(json.dumps({
|
||||||
'action': 'data',
|
'action': 'data',
|
||||||
'packet': packet
|
'packet': packet,
|
||||||
|
'conn_id': self.conn_id
|
||||||
}).encode(), self.remote_addr)
|
}).encode(), self.remote_addr)
|
||||||
self.retransmit_timers[seq] = now
|
self.retransmit_timers[seq] = now
|
||||||
break
|
break
|
||||||
@ -181,6 +187,7 @@ class ServiceProvider:
|
|||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
data, addr = self.udp_sock.recvfrom(65535)
|
data, addr = self.udp_sock.recvfrom(65535)
|
||||||
|
print(f"收到UDP消息: {data.decode()}")
|
||||||
try:
|
try:
|
||||||
message = json.loads(data.decode())
|
message = json.loads(data.decode())
|
||||||
action = message.get('action')
|
action = message.get('action')
|
||||||
@ -216,6 +223,7 @@ class ServiceProvider:
|
|||||||
data_chunk = channel.process_packet(packet)
|
data_chunk = channel.process_packet(packet)
|
||||||
if data_chunk and conn_id in self.active_connections:
|
if data_chunk and conn_id in self.active_connections:
|
||||||
# 转发数据到本地服务
|
# 转发数据到本地服务
|
||||||
|
print(f"转发数据包到本地服务: {data_chunk[:10]}")
|
||||||
self.active_connections[conn_id]['local_sock'].sendall(data_chunk)
|
self.active_connections[conn_id]['local_sock'].sendall(data_chunk)
|
||||||
elif action == 'ack':
|
elif action == 'ack':
|
||||||
# 处理ACK确认
|
# 处理ACK确认
|
||||||
@ -249,7 +257,7 @@ class ServiceProvider:
|
|||||||
local_sock.connect(('127.0.0.1', self.internal_port))
|
local_sock.connect(('127.0.0.1', self.internal_port))
|
||||||
|
|
||||||
# 创建可靠通道
|
# 创建可靠通道
|
||||||
channel = ReliableChannel(self.udp_sock, client_addr)
|
channel = ReliableChannel(self.udp_sock, client_addr, conn_id)
|
||||||
self.reliable_channels[conn_id] = channel
|
self.reliable_channels[conn_id] = channel
|
||||||
|
|
||||||
# 存储连接
|
# 存储连接
|
||||||
@ -273,12 +281,6 @@ class ServiceProvider:
|
|||||||
daemon=True
|
daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
threading.Thread(
|
|
||||||
target=self.forward_incoming_data,
|
|
||||||
args=(conn_id, channel, local_sock),
|
|
||||||
daemon=True
|
|
||||||
).start()
|
|
||||||
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self.monitor_channel,
|
target=self.monitor_channel,
|
||||||
args=(conn_id, channel),
|
args=(conn_id, channel),
|
||||||
@ -294,14 +296,6 @@ class ServiceProvider:
|
|||||||
'message': str(e)
|
'message': str(e)
|
||||||
}).encode(), client_addr)
|
}).encode(), client_addr)
|
||||||
|
|
||||||
def forward_incoming_data(self, conn_id, channel, local_sock):
|
|
||||||
while conn_id in self.active_connections:
|
|
||||||
data = channel.get_received_data()
|
|
||||||
if data:
|
|
||||||
local_sock.sendall(data)
|
|
||||||
else:
|
|
||||||
time.sleep(0.02)
|
|
||||||
|
|
||||||
def forward_data(self, conn_id, local_sock, channel):
|
def forward_data(self, conn_id, local_sock, channel):
|
||||||
"""转发TCP数据到UDP通道"""
|
"""转发TCP数据到UDP通道"""
|
||||||
try:
|
try:
|
||||||
@ -387,7 +381,7 @@ class ServiceProvider:
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
COORDINATOR_ADDR = ('www.awin-x.top', 5000)
|
COORDINATOR_ADDR = ('www.awin-x.top', 5000)
|
||||||
SERVICE_NAME = "my_service"
|
SERVICE_NAME = "my_service"
|
||||||
INTERNAL_PORT = 5001
|
INTERNAL_PORT = 22
|
||||||
|
|
||||||
provider = ServiceProvider(COORDINATOR_ADDR, SERVICE_NAME, INTERNAL_PORT)
|
provider = ServiceProvider(COORDINATOR_ADDR, SERVICE_NAME, INTERNAL_PORT)
|
||||||
provider.run()
|
provider.run()
|
||||||
Loading…
Reference in New Issue
Block a user