387 lines
14 KiB
Python
387 lines
14 KiB
Python
import socket
|
||
import json
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from collections import deque
|
||
|
||
|
||
class ReliableChannel:
    """Reliable, ordered byte delivery over a shared UDP socket.

    Implements sequence numbers, cumulative ACKs, a fixed-size sliding send
    window and timeout-based retransmission. Outgoing and incoming sequence
    numbers are tracked by separate counters. Every packet is sent as JSON
    with a hex-encoded payload, so large payloads are split into chunks of at
    most ``max_payload`` bytes to keep each datagram below the UDP size limit.

    The channel may be driven from several threads at once (receive path,
    send path, retransmit/ACK timer), so all mutable state is guarded by
    ``_lock``.
    """

    # Hex encoding doubles the payload: 16 KiB -> 32 KiB of hex plus the JSON
    # envelope stays well below the 65,507-byte IPv4 UDP payload limit.
    # May be overridden per instance.
    max_payload = 16 * 1024

    def __init__(self, udp_sock, remote_addr, conn_id):
        self.udp_sock = udp_sock
        self.remote_addr = remote_addr
        self.send_buffer = deque()  # unacknowledged outgoing packets, in seq order
        self.recv_buffer = {}  # out-of-order incoming payloads (seq -> bytes)
        self.recv_queue = deque()  # read by get_received_data(); not filled by process_packet()
        self.next_send_seq = 0  # sequence number for the next outgoing packet
        self.expected_seq = 0  # next in-order sequence number expected from the peer
        self.last_ack = -1  # highest cumulative ACK received from the peer
        self.ack_interval = 0.2  # ACK pacing interval (seconds) for callers sending periodic ACKs
        self.last_ack_time = 0
        self.window_size = 10  # max number of unacknowledged packets in flight
        self.retransmit_timers = {}  # {seq: last transmit time}, only for packets actually sent
        self.retransmit_timeout = 0.5  # seconds before an unacknowledged packet is resent
        self.conn_id = conn_id
        self.running = True
        self._lock = threading.RLock()

    def send(self, data, is_control=False):
        """Queue *data* for reliable delivery and transmit what the window allows.

        Data longer than ``max_payload`` is split into several packets with
        consecutive sequence numbers, so an in-order receiver concatenates
        them back into the original bytes. Does nothing after ``close()``.
        """
        print(f"发送数据: {data}")
        if not self.running:
            return

        with self._lock:
            # max(len, 1) keeps the original behaviour of emitting one packet for b''.
            for offset in range(0, max(len(data), 1), self.max_payload):
                self.send_buffer.append({
                    'seq': self.next_send_seq,
                    'data': data[offset:offset + self.max_payload].hex(),
                    'is_control': is_control,
                    'timestamp': time.time()
                })
                self.next_send_seq += 1
            self._send_from_buffer()

    def _send_from_buffer(self):
        """Drop acknowledged packets, then transmit unsent packets inside the window."""
        with self._lock:
            while self.send_buffer and self.send_buffer[0]['seq'] <= self.last_ack:
                self.send_buffer.popleft()

            for index, packet in enumerate(self.send_buffer):
                if index >= self.window_size:
                    break
                seq = packet['seq']
                # Packets already in flight are resent only by check_retransmit().
                if seq not in self.retransmit_timers:
                    self._transmit(packet)
                    self.retransmit_timers[seq] = time.time()

    def _transmit(self, packet):
        """Send one data packet to the remote peer as a JSON datagram."""
        self.udp_sock.sendto(json.dumps({
            'action': 'data',
            'packet': packet,
            'conn_id': self.conn_id
        }).encode(), self.remote_addr)

    def process_ack(self, ack_seq):
        """Handle a cumulative ACK: everything up to and including *ack_seq* is delivered."""
        print(f"收到ACK: {ack_seq}")
        with self._lock:
            if ack_seq <= self.last_ack:
                return  # stale or duplicate ACK
            self.last_ack = ack_seq
            for seq in [s for s in self.retransmit_timers if s <= ack_seq]:
                del self.retransmit_timers[seq]
            # The window moved forward: release acked packets and send more.
            self._send_from_buffer()

    def process_packet(self, packet):
        """Handle an incoming data packet and return newly deliverable bytes.

        Returns the in-order payload concatenated with any buffered payloads it
        unblocks, or ``None`` when the packet was buffered (out of order) or is
        a duplicate. A cumulative ACK reflecting the *updated* state is sent for
        every packet, so the peer also learns promptly when an ACK was lost.
        """
        seq = packet['seq']
        data = bytes.fromhex(packet['data'])

        with self._lock:
            if seq == self.expected_seq:
                self.expected_seq += 1
                result = b''.join([data] + self._process_buffered())
            elif seq > self.expected_seq:
                self.recv_buffer[seq] = data
                result = None
            else:
                result = None  # duplicate of data already delivered

            self.send_ack()
            self.last_ack_time = time.time()
        return result

    def _process_buffered(self):
        """Pop buffered payloads that are now contiguous, advancing ``expected_seq``.

        Returns the popped payloads in sequence order (possibly empty).
        """
        drained = []
        while self.expected_seq in self.recv_buffer:
            drained.append(self.recv_buffer.pop(self.expected_seq))
            self.expected_seq += 1
        return drained

    def get_received_data(self):
        """Pop the oldest entry of ``recv_queue``, or return None when it is empty."""
        return self.recv_queue.popleft() if self.recv_queue else None

    def send_ack(self):
        """Send a cumulative ACK for the highest contiguous sequence number received."""
        with self._lock:
            ack_packet = {
                'action': 'ack',
                'ack_seq': self.expected_seq - 1,
                'window': self.window_size,
                'conn_id': self.conn_id
            }
        self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)

    def check_retransmit(self):
        """Resend in-flight packets whose last transmission is older than the timeout."""
        now = time.time()
        with self._lock:
            for index, packet in enumerate(self.send_buffer):
                if index >= self.window_size:
                    break  # packets beyond the window have never been sent
                seq = packet['seq']
                sent_at = self.retransmit_timers.get(seq)
                if sent_at is not None and now - sent_at > self.retransmit_timeout:
                    self._transmit(packet)
                    self.retransmit_timers[seq] = now

    def close(self):
        """Stop the channel and discard all buffered state."""
        with self._lock:
            self.running = False
            self.send_buffer.clear()
            self.recv_buffer.clear()
            self.retransmit_timers.clear()
|
||
|
||
|
||
class ServiceProvider:
    """Expose a local TCP service to remote clients over UDP hole punching.

    The provider registers ``service_name`` with a coordinator, answers punch
    probes, and for every ``connect`` request opens a TCP connection to
    ``127.0.0.1:internal_port`` whose bytes are relayed through a
    ``ReliableChannel`` keyed by the requested ``conn_id``.
    """

    # Per-attempt wait (seconds) for the coordinator's registration reply.
    register_timeout = 3.0
    # Number of registration requests sent before giving up.
    register_attempts = 3
    # Max seconds to wait for queued data to be ACKed after the local service
    # closes its side, so the tail of the stream is not discarded.
    drain_timeout = 5.0

    def __init__(self, coordinator_addr, service_name, internal_port):
        self.coordinator_addr = coordinator_addr
        self.service_name = service_name
        self.internal_port = internal_port
        self.client_id = f"provider-{uuid.uuid4().hex[:8]}"

        # One UDP socket is shared by registration, punching and all channels.
        self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Options are set before bind() so they apply to the bind itself.
        self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)  # 1 MiB receive buffer
        self.udp_sock.bind(('0.0.0.0', 0))
        # A finite timeout lets udp_listener() wake up and notice shutdown.
        self.udp_sock.settimeout(1.0)
        self.punch_port = self.udp_sock.getsockname()[1]

        # conn_id -> {'local_sock', 'client_addr', 'client_id', 'channel'}
        self.active_connections = {}
        self.reliable_channels = {}  # conn_id -> ReliableChannel
        self.running = True

    def _send_json(self, payload, addr):
        """Serialize *payload* as JSON and send it as one datagram to *addr*."""
        self.udp_sock.sendto(json.dumps(payload).encode(), addr)

    def register_service(self):
        """Register the service with the coordinator.

        Sends the request up to ``register_attempts`` times, waiting
        ``register_timeout`` seconds for a JSON reply with a ``status`` field.
        Returns True on success, False on an explicit failure or no reply.
        """
        message = {
            'action': 'register',
            'service_name': self.service_name,
            'internal_port': self.internal_port,
            'external_port': self.punch_port,
            'client_id': self.client_id
        }
        print(f"向协调服务器 {self.coordinator_addr} 注册服务 '{self.service_name}'")

        previous_timeout = self.udp_sock.gettimeout()
        self.udp_sock.settimeout(self.register_timeout)
        try:
            for _ in range(self.register_attempts):
                self._send_json(message, self.coordinator_addr)
                response = self._await_status_reply()
                if response is None:
                    continue  # UDP reply lost or coordinator slow: retry
                if response.get('status') == 'success':
                    print(f"服务 '{self.service_name}' 注册成功")
                    return True
                print(f"注册失败: {response.get('message')}")
                return False
        finally:
            self.udp_sock.settimeout(previous_timeout)

        print("注册失败: 协调服务器无响应")
        return False

    def _await_status_reply(self):
        """Return the first JSON object with a 'status' key, or None on timeout."""
        deadline = time.monotonic() + self.register_timeout
        while time.monotonic() < deadline:
            try:
                data, _ = self.udp_sock.recvfrom(4096)
            except socket.timeout:
                return None
            try:
                response = json.loads(data.decode())
            except ValueError:  # undecodable bytes or non-JSON (e.g. a raw punch packet)
                continue
            if isinstance(response, dict) and 'status' in response:
                return response
        return None

    def udp_listener(self):
        """Receive datagrams on the shared UDP socket and dispatch them until stopped."""
        while self.running:
            try:
                data, addr = self.udp_sock.recvfrom(65535)
            except socket.timeout:
                continue  # periodic wake-up so self.running is re-checked
            except Exception as e:
                print(f"UDP监听异常: {str(e)}")
                continue

            # errors='replace': a non-UTF-8 datagram must not abort logging.
            print(f"收到UDP消息: {data.decode(errors='replace')}")
            try:
                message = json.loads(data.decode())
            except ValueError:
                continue  # non-JSON datagrams may be raw punch packets
            try:
                self._dispatch(message, addr)
            except Exception as e:
                print(f"UDP监听错误: {str(e)}")

    def _dispatch(self, message, addr):
        """Route one decoded UDP message by its 'action' field."""
        action = message.get('action')

        if action in ('punch', 'punch_check'):
            self._send_json({'action': 'punch_response', 'client_id': self.client_id}, addr)
        elif action == 'punch_response':
            pass  # the reply only needs to reach us; nothing to do
        elif action == 'connect':
            threading.Thread(
                target=self.handle_connection,
                args=(message['conn_id'], addr, message['client_id']),
                daemon=True
            ).start()
        elif action == 'punch_request':
            # punch_hole() sleeps between probes; keep the listener responsive.
            threading.Thread(target=self.punch_hole, args=(message,), daemon=True).start()
        elif action == 'data':
            self._handle_data(message)
        elif action == 'ack':
            conn_id = message.get('conn_id')
            channel = self.reliable_channels.get(conn_id) if conn_id else None
            if channel is not None:
                channel.process_ack(message['ack_seq'])
        elif action == 'stop_conn':
            self.close_connection(message['conn_id'])
        elif action == 'stop_client':
            self.close_client_connections(message['client_id'])

    def _handle_data(self, message):
        """Feed a data packet to its channel and write deliverable bytes to the local service."""
        conn_id = message.get('conn_id')
        channel = self.reliable_channels.get(conn_id) if conn_id else None
        if channel is None:
            return

        data_chunk = channel.process_packet(message['packet'])
        conn = self.active_connections.get(conn_id)
        if not data_chunk or conn is None:
            return

        print(f"转发数据包到本地服务: {data_chunk[:10]}")
        try:
            conn['local_sock'].sendall(data_chunk)
        except OSError as e:
            # These bytes were already ACKed to the peer, so the stream cannot
            # be recovered; close instead of silently losing data.
            print(f"数据转发错误: {str(e)}")
            self.close_connection(conn_id)

    def handle_connection(self, conn_id, client_addr, client_id):
        """Connect to the local service on behalf of a client and start relaying.

        Replies ``connected`` on success or ``connect_failed`` (with the error
        text) on failure. A repeated ``connect`` for an already-active
        ``conn_id`` only re-sends ``connected`` instead of opening a second
        local connection and orphaning the first.
        """
        if conn_id in self.active_connections:
            self._send_json({'action': 'connected', 'conn_id': conn_id}, client_addr)
            return

        local_sock = None
        registered = False
        try:
            local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Bounds connect() and later blocking sendall()/recv() calls;
            # forward_data() treats a recv() timeout as idleness, not EOF.
            local_sock.settimeout(5.0)
            local_sock.connect(('127.0.0.1', self.internal_port))

            channel = ReliableChannel(self.udp_sock, client_addr, conn_id)
            self.reliable_channels[conn_id] = channel
            self.active_connections[conn_id] = {
                'local_sock': local_sock,
                'client_addr': client_addr,
                'client_id': client_id,
                'channel': channel
            }
            registered = True

            self._send_json({'action': 'connected', 'conn_id': conn_id}, client_addr)

            threading.Thread(
                target=self.forward_data,
                args=(conn_id, local_sock, channel),
                daemon=True
            ).start()
            threading.Thread(
                target=self.monitor_channel,
                args=(conn_id, channel),
                daemon=True
            ).start()

            print(f"建立连接 {conn_id}: {client_addr} -> 本地服务")
        except Exception as e:
            print(f"连接失败: {str(e)}")
            # Undo any partial registration and release the TCP socket.
            if registered:
                self.close_connection(conn_id)
            if local_sock is not None:
                local_sock.close()
            self._send_json({
                'action': 'connect_failed',
                'conn_id': conn_id,
                'message': str(e)
            }, client_addr)

    def forward_data(self, conn_id, local_sock, channel):
        """Relay bytes from the local TCP service into the reliable channel.

        Runs until the local side closes, the connection is removed, or an
        error occurs. The connection is always closed on exit.
        """
        try:
            while conn_id in self.active_connections:
                try:
                    data = local_sock.recv(32768)  # 32 KiB read size
                except socket.timeout:
                    # An idle local service (e.g. an SSH session waiting for
                    # input) is not an error; keep the connection open.
                    continue
                if not data:
                    self._drain(conn_id, channel)
                    break
                channel.send(data)
        except Exception as e:
            print(f"数据转发错误: {str(e)}")
        finally:
            self.close_connection(conn_id)

    def _drain(self, conn_id, channel):
        """Wait, at most ``drain_timeout`` seconds, until the channel's send buffer empties."""
        deadline = time.monotonic() + self.drain_timeout
        while (channel.send_buffer
               and conn_id in self.active_connections
               and time.monotonic() < deadline):
            time.sleep(0.05)

    def monitor_channel(self, conn_id, channel):
        """Drive retransmissions and periodic ACKs while the connection is active."""
        while self.running and conn_id in self.active_connections:
            try:
                channel.check_retransmit()
                if time.time() - channel.last_ack_time > channel.ack_interval:
                    channel.send_ack()
                    channel.last_ack_time = time.time()
            except Exception as e:
                print(f"通道监控错误: {str(e)}")
            # Sleep on every path so a persistent error cannot busy-loop.
            time.sleep(0.1)

    def close_connection(self, conn_id):
        """Close the local socket and channel for *conn_id*; a no-op if already closed.

        Uses pop-with-default because several threads may race to close the
        same connection.
        """
        conn = self.active_connections.pop(conn_id, None)
        channel = self.reliable_channels.pop(conn_id, None)
        if conn is None and channel is None:
            return
        if conn is not None:
            conn['local_sock'].close()
        if channel is not None:
            channel.close()
        print(f"连接 {conn_id} 已关闭")

    def close_client_connections(self, client_id):
        """Close every active connection that belongs to *client_id*."""
        for conn_id, info in list(self.active_connections.items()):
            if info.get('client_id') == client_id:
                self.close_connection(conn_id)

    def run(self):
        """Register, start the UDP listener, and block until Ctrl+C triggers shutdown."""
        try:
            self.register_service()
            threading.Thread(target=self.udp_listener, daemon=True).start()
            while self.running:
                time.sleep(1)
        except KeyboardInterrupt:
            self.shutdown()

    def punch_hole(self, message):
        """Send five punch probes, 0.2 s apart, to ``message['client_addr']``."""
        try:
            target = tuple(message['client_addr'])
        except (KeyError, TypeError) as e:
            print(f"打洞失败: {str(e)}")
            return

        payload = {'action': 'punch', 'client_id': self.client_id}
        for _ in range(5):
            try:
                self._send_json(payload, target)
            except Exception as e:
                print(f"打洞失败: {str(e)}")
            time.sleep(0.2)

    def shutdown(self):
        """Notify the coordinator, close all connections and the UDP socket."""
        self.running = False
        try:
            # NOTE(review): carries no provider identity; presumably the
            # coordinator identifies the provider by source address - verify.
            self._send_json({'action': 'stop_provider'}, self.coordinator_addr)
        except OSError as e:
            # Best effort: local cleanup must still happen.
            print(f"通知协调服务器失败: {str(e)}")
        for conn_id in list(self.active_connections):
            self.close_connection(conn_id)
        self.udp_sock.close()
        print("服务提供端已停止")
|
||
|
||
|
||
if __name__ == '__main__':
    # Deployment settings: the coordinator endpoint, the name this service is
    # registered under, and the local TCP port whose traffic is relayed
    # (port 22 is conventionally SSH).
    COORDINATOR_ADDR = ('www.awin-x.top', 5000)
    SERVICE_NAME = "my_service"
    INTERNAL_PORT = 22

    service_provider = ServiceProvider(
        coordinator_addr=COORDINATOR_ADDR,
        service_name=SERVICE_NAME,
        internal_port=INTERNAL_PORT,
    )
    service_provider.run()