AwinSimpleP2P/provider1.py
2025-05-31 17:52:01 +08:00

387 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import socket
import json
import threading
import time
import uuid
from collections import deque
class ReliableChannel:
    """Reliable, ordered delivery over UDP for one tunnelled connection.

    Outgoing data is split into numbered packets. Up to ``window_size``
    unacknowledged packets are in flight at once, and check_retransmit()
    resends them until a cumulative ACK from the peer covers them. Incoming
    packets are reordered by sequence number and handed back in order by
    process_packet().

    Sequence numbers are independent per direction: ``next_send_seq`` numbers
    outgoing packets and ``expected_seq`` tracks the next incoming one. The
    peer's channel must number its packets the same way.

    send(), process_ack() and check_retransmit() are called from different
    threads (the forwarding, UDP-listener and monitor threads of
    ServiceProvider), so all send-side state is guarded by ``_lock``.
    """

    # Largest raw payload carried by one packet. Payloads are hex-encoded
    # (doubling their size) and wrapped in JSON. A 32 KiB chunk would therefore
    # exceed the 65,507-byte IPv4 UDP payload limit; 16 KiB stays well below it.
    MAX_PAYLOAD = 16 * 1024

    def __init__(self, udp_sock, remote_addr, conn_id):
        self.udp_sock = udp_sock  # shared UDP socket used for every send
        self.remote_addr = remote_addr  # peer address passed to sendto()
        self.send_buffer = deque()  # unacknowledged outgoing packets, oldest first
        self.recv_buffer = {}  # out-of-order incoming payloads: seq -> bytes
        # Drained by get_received_data(); process_packet() returns data directly
        # and does not fill this queue.
        self.recv_queue = deque()
        self.next_send_seq = 0  # sequence number for the next outgoing packet
        self.expected_seq = 0  # next in-order sequence number expected from the peer
        self.last_ack = -1  # highest outgoing seq acknowledged by the peer
        self.ack_interval = 0.2  # minimum seconds between rate-limited ACKs
        self.last_ack_time = 0
        self.window_size = 10  # max packets transmitted but not yet acknowledged
        self.retransmit_timers = {}  # seq -> time the packet was last transmitted
        self.retransmit_timeout = 0.5  # seconds before an unacknowledged packet is resent
        self.conn_id = conn_id
        self.running = True
        # Guards send_buffer, retransmit_timers, last_ack and next_send_seq.
        self._lock = threading.Lock()

    def send(self, data, is_control=False):
        """Queue ``data`` for reliable delivery and transmit what the window allows.

        Payloads longer than MAX_PAYLOAD bytes are split over several
        consecutive packets. The receiver delivers them in order, so a byte
        stream is reassembled unchanged. Does nothing once the channel is closed.

        Args:
            data: bytes to send (an empty payload still produces one packet).
            is_control: copied into every packet's ``is_control`` field.
        """
        print(f"发送数据: {data}")
        with self._lock:
            if not self.running:
                return
            # max(..., 1) keeps the original behaviour of sending one packet for b''.
            for start in range(0, max(len(data), 1), self.MAX_PAYLOAD):
                self.send_buffer.append({
                    'seq': self.next_send_seq,
                    'data': data[start:start + self.MAX_PAYLOAD].hex(),
                    'is_control': is_control,
                    'timestamp': time.time(),
                })
                self.next_send_seq += 1
            self._send_from_buffer()

    def _send_from_buffer(self):
        """Drop acknowledged packets, then transmit packets that just entered the window.

        Packets already in flight (they have a retransmit timer) are left to
        check_retransmit(). Re-sending the whole window on every call would
        multiply traffic. The caller must hold ``self._lock``.
        """
        while self.send_buffer and self.send_buffer[0]['seq'] <= self.last_ack:
            self.send_buffer.popleft()
        now = time.time()
        for index, packet in enumerate(self.send_buffer):
            if index >= self.window_size:
                break
            if packet['seq'] not in self.retransmit_timers:
                self._transmit(packet)
                self.retransmit_timers[packet['seq']] = now

    def _transmit(self, packet):
        """Wrap ``packet`` in a 'data' message and send it to the peer."""
        self.udp_sock.sendto(json.dumps({
            'action': 'data',
            'packet': packet,
            'conn_id': self.conn_id
        }).encode(), self.remote_addr)

    def process_ack(self, ack_seq):
        """Apply a cumulative ACK: every packet with seq <= ``ack_seq`` was received."""
        print(f"收到ACK: {ack_seq}")
        with self._lock:
            # An ACK cannot cover sequence numbers that have not been assigned
            # yet. Clamping stops a misbehaving peer from making us silently
            # discard data queued later.
            ack_seq = min(ack_seq, self.next_send_seq - 1)
            if ack_seq > self.last_ack:
                self.last_ack = ack_seq
                for seq in [s for s in self.retransmit_timers if s <= ack_seq]:
                    del self.retransmit_timers[seq]
            # The window may have moved: send packets that now fit.
            self._send_from_buffer()

    def process_packet(self, packet):
        """Accept one incoming data packet.

        Args:
            packet: dict with ``seq`` (int) and ``data`` (hex string).

        Returns:
            The bytes that became deliverable in order: this packet's payload
            plus any buffered payloads it unblocked. Returns None for
            out-of-order packets (which are buffered) and for duplicates.
        """
        seq = packet['seq']
        data = bytes.fromhex(packet['data'])
        if seq == self.expected_seq:
            self.expected_seq += 1
            result = data + self._process_buffered()
        elif seq > self.expected_seq:
            self.recv_buffer[seq] = data
            result = None
        else:
            result = None  # duplicate of something already delivered
        # ACK after updating expected_seq, so the ACK reflects this packet.
        if time.time() - self.last_ack_time > self.ack_interval:
            self.send_ack()
            self.last_ack_time = time.time()
        return result

    def _process_buffered(self):
        """Pop buffered payloads that are now contiguous with ``expected_seq``.

        Advances ``expected_seq`` past them and returns them concatenated
        (b'' if there are none).
        """
        chunks = []
        while self.expected_seq in self.recv_buffer:
            chunks.append(self.recv_buffer.pop(self.expected_seq))
            self.expected_seq += 1
        return b''.join(chunks)

    def get_received_data(self):
        """Pop the oldest item from ``recv_queue``, or return None if it is empty."""
        return self.recv_queue.popleft() if self.recv_queue else None

    def send_ack(self):
        """Send a cumulative ACK for the highest contiguous seq received so far."""
        ack_packet = {
            'action': 'ack',
            'ack_seq': self.expected_seq - 1,  # highest contiguous seq received
            'window': self.window_size,
            'conn_id': self.conn_id
        }
        self.udp_sock.sendto(json.dumps(ack_packet).encode(), self.remote_addr)

    def check_retransmit(self):
        """Resend every in-flight packet whose last transmission is older than the timeout."""
        now = time.time()
        with self._lock:
            # Transmitted packets always form a prefix of send_buffer, so stop
            # at the first packet that has no timer.
            for packet in self.send_buffer:
                sent_at = self.retransmit_timers.get(packet['seq'])
                if sent_at is None:
                    break
                if now - sent_at > self.retransmit_timeout:
                    self._transmit(packet)
                    self.retransmit_timers[packet['seq']] = now

    def close(self):
        """Stop the channel and drop all buffered and pending data."""
        with self._lock:
            self.running = False
            self.send_buffer.clear()
            self.recv_buffer.clear()
            self.retransmit_timers.clear()
class ServiceProvider:
    """Provider side of the P2P tunnel: exposes a local TCP service to remote clients.

    Registers ``service_name`` with the coordinator over UDP and answers
    hole-punching probes. For every client ``connect`` it bridges a TCP
    connection to ``127.0.0.1:internal_port`` with a ReliableChannel running
    over the single shared UDP socket.
    """

    REGISTER_TIMEOUT = 5.0  # seconds to wait for each registration reply
    REGISTER_ATTEMPTS = 3  # registration requests sent before giving up
    LOCAL_CONNECT_TIMEOUT = 5.0  # seconds allowed for connecting to the local service
    LOCAL_READ_SIZE = 32768  # bytes read from the local service per recv()

    def __init__(self, coordinator_addr, service_name, internal_port):
        self.coordinator_addr = coordinator_addr
        self.service_name = service_name
        self.internal_port = internal_port
        self.client_id = f"provider-{uuid.uuid4().hex[:8]}"
        # One UDP socket carries coordinator traffic, hole punching and all channels.
        self.udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Options are set before bind(): SO_REUSEADDR has no effect afterwards.
        self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.udp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)  # 1 MB receive buffer
        self.udp_sock.bind(('0.0.0.0', 0))
        self.punch_port = self.udp_sock.getsockname()[1]
        # conn_id -> {'local_sock', 'client_addr', 'client_id', 'channel'}
        self.active_connections = {}
        self.reliable_channels = {}  # conn_id -> ReliableChannel
        # conn_ids whose handle_connection() is still setting up. Used to ignore
        # a repeated 'connect' instead of opening a second local connection.
        self._pending_connections = set()
        self.running = True

    def register_service(self):
        """Register this service with the coordinator and wait for its reply.

        UDP can drop either datagram, so the request is sent up to
        REGISTER_ATTEMPTS times, waiting REGISTER_TIMEOUT seconds for each reply.
        The socket is returned to blocking mode afterwards.

        Returns:
            True if the coordinator answered with ``status == 'success'``, else False.
        """
        message = {
            'action': 'register',
            'service_name': self.service_name,
            'internal_port': self.internal_port,
            'external_port': self.punch_port,
            'client_id': self.client_id
        }
        print(f"向协调服务器 {self.coordinator_addr} 注册服务 '{self.service_name}'")
        payload = json.dumps(message).encode()
        self.udp_sock.settimeout(self.REGISTER_TIMEOUT)
        try:
            for _attempt in range(self.REGISTER_ATTEMPTS):
                self.udp_sock.sendto(payload, self.coordinator_addr)
                try:
                    data, _addr = self.udp_sock.recvfrom(4096)
                    response = json.loads(data.decode())
                except socket.timeout:
                    continue
                except ValueError:  # undecodable reply (UnicodeDecodeError / JSONDecodeError)
                    continue
                if response.get('status') == 'success':
                    print(f"服务 '{self.service_name}' 注册成功")
                    return True
                print(f"注册失败: {response.get('message')}")
                return False
            print("注册失败: 协调服务器无响应")
            return False
        finally:
            self.udp_sock.settimeout(None)

    def udp_listener(self):
        """Receive and dispatch UDP datagrams until ``self.running`` is cleared."""
        while self.running:
            try:
                data, addr = self.udp_sock.recvfrom(65535)
            except socket.timeout:
                continue
            except Exception as e:
                print(f"UDP监听异常: {str(e)}")
                continue
            # errors='replace' so non-UTF-8 (e.g. raw punch) datagrams do not raise here.
            print(f"收到UDP消息: {data.decode(errors='replace')}")
            try:
                message = json.loads(data)
            except ValueError:
                # Non-JSON data is probably a raw hole-punching packet.
                continue
            try:
                self._dispatch(message, addr)
            except Exception as e:
                print(f"UDP监听错误: {str(e)}")

    def _dispatch(self, message, addr):
        """Route one decoded JSON message received from ``addr`` by its 'action'."""
        action = message.get('action')
        if action in ('punch', 'punch_check'):
            # Hole-punch probe: answer it.
            self.udp_sock.sendto(json.dumps(
                {'action': 'punch_response', 'client_id': self.client_id}
            ).encode(), addr)
        elif action == 'punch_response':
            pass  # no action needed
        elif action == 'connect':
            self._on_connect(message, addr)
        elif action == 'punch_request':
            # punch_hole() sleeps between probes; run it off the listener
            # thread so data/ACK handling for other connections is not stalled.
            threading.Thread(target=self.punch_hole, args=(message,), daemon=True).start()
        elif action == 'data':
            self._on_data(message)
        elif action == 'ack':
            channel = self._channel_for(message)
            if channel is not None:
                channel.process_ack(message['ack_seq'])
        elif action == 'stop_conn':
            self.close_connection(message['conn_id'])
        elif action == 'stop_client':
            self.close_client_connections(message['client_id'])

    def _channel_for(self, message):
        """Return the ReliableChannel named by ``message['conn_id']``, or None."""
        conn_id = message.get('conn_id')
        return self.reliable_channels.get(conn_id) if conn_id else None

    def _on_connect(self, message, addr):
        """Start handle_connection() for a new conn_id; ignore or re-acknowledge repeats."""
        conn_id = message['conn_id']
        client_id = message['client_id']
        if conn_id in self.active_connections:
            # Retransmitted 'connect' (our 'connected' may have been lost): repeat
            # the acknowledgement rather than opening a second local socket.
            self.udp_sock.sendto(json.dumps({
                'action': 'connected',
                'conn_id': conn_id
            }).encode(), addr)
            return
        if conn_id in self._pending_connections:
            return  # setup already in progress
        self._pending_connections.add(conn_id)
        threading.Thread(
            target=self.handle_connection,
            args=(conn_id, addr, client_id),
            daemon=True
        ).start()

    def _on_data(self, message):
        """Feed a 'data' packet to its channel and forward in-order bytes to the local service."""
        conn_id = message.get('conn_id')
        channel = self._channel_for(message)
        if channel is None:
            return
        data_chunk = channel.process_packet(message['packet'])
        conn = self.active_connections.get(conn_id)
        if data_chunk and conn is not None:
            print(f"转发数据包到本地服务: {data_chunk[:10]}")
            conn['local_sock'].sendall(data_chunk)

    def handle_connection(self, conn_id, client_addr, client_id):
        """Connect to the local service and bridge it to ``client_addr`` as ``conn_id``.

        On success the client receives 'connected'. On failure it receives
        'connect_failed' and any partially created resources are released.
        """
        local_sock = None
        try:
            local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            local_sock.settimeout(self.LOCAL_CONNECT_TIMEOUT)
            local_sock.connect(('127.0.0.1', self.internal_port))
            # The timeout only guards connect(). Kept during data transfer, it
            # would make forward_data() close an idle connection after a few
            # seconds and let sendall() time out after the data was already ACKed.
            local_sock.settimeout(None)
            channel = ReliableChannel(self.udp_sock, client_addr, conn_id)
            self.reliable_channels[conn_id] = channel
            self.active_connections[conn_id] = {
                'local_sock': local_sock,
                'client_addr': client_addr,
                'client_id': client_id,
                'channel': channel
            }
            self.udp_sock.sendto(json.dumps({
                'action': 'connected',
                'conn_id': conn_id
            }).encode(), client_addr)
            threading.Thread(
                target=self.forward_data,
                args=(conn_id, local_sock, channel),
                daemon=True
            ).start()
            threading.Thread(
                target=self.monitor_channel,
                args=(conn_id, channel),
                daemon=True
            ).start()
            print(f"建立连接 {conn_id}: {client_addr} -> 本地服务")
        except Exception as e:
            print(f"连接失败: {str(e)}")
            # conn_id is new here (_on_connect skips known ones), so this only
            # tears down what this call registered; then close the local socket
            # in case it was never registered.
            self.close_connection(conn_id)
            if local_sock is not None:
                local_sock.close()
            try:
                self.udp_sock.sendto(json.dumps({
                    'action': 'connect_failed',
                    'conn_id': conn_id,
                    'message': str(e)
                }).encode(), client_addr)
            except OSError as send_err:
                print(f"连接失败: {str(send_err)}")
        finally:
            self._pending_connections.discard(conn_id)

    def forward_data(self, conn_id, local_sock, channel):
        """Copy bytes from the local TCP service into ``channel`` until EOF or close."""
        try:
            while conn_id in self.active_connections:
                data = local_sock.recv(self.LOCAL_READ_SIZE)
                if not data:
                    break
                channel.send(data)
        except socket.timeout:
            pass
        except Exception as e:
            # After close_connection() the socket is closed under recv(); that
            # error is expected, so only report failures of a live connection.
            if conn_id in self.active_connections:
                print(f"数据转发错误: {str(e)}")
        finally:
            self.close_connection(conn_id)

    def monitor_channel(self, conn_id, channel):
        """Every 0.1 s run retransmission checks and send periodic ACKs for ``conn_id``."""
        while conn_id in self.active_connections:
            try:
                channel.check_retransmit()
                if time.time() - channel.last_ack_time > channel.ack_interval:
                    channel.send_ack()
                    channel.last_ack_time = time.time()
            except Exception as e:
                print(f"通道监控错误: {str(e)}")
            # Sleep outside the try so a persistent error cannot become a busy loop.
            time.sleep(0.1)

    def close_connection(self, conn_id):
        """Close the local socket and channel of ``conn_id``; safe to call repeatedly."""
        # pop(..., None) so concurrent closers (listener, forwarder) cannot KeyError.
        conn = self.active_connections.pop(conn_id, None)
        if conn is not None:
            local_sock = conn['local_sock']
            try:
                # shutdown() wakes a forward_data() thread blocked in recv();
                # close() alone may not.
                local_sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # already disconnected
            local_sock.close()
        channel = self.reliable_channels.pop(conn_id, None)
        if channel is not None:
            channel.close()
            print(f"连接 {conn_id} 已关闭")

    def close_client_connections(self, client_id):
        """Close every connection opened by ``client_id``."""
        for conn_id, info in list(self.active_connections.items()):
            if info['client_id'] == client_id:
                self.close_connection(conn_id)

    def run(self):
        """Register, start the UDP listener thread, and block until Ctrl+C."""
        self.register_service()
        threading.Thread(target=self.udp_listener, daemon=True).start()
        try:
            while self.running:
                time.sleep(1)
        except KeyboardInterrupt:
            self.shutdown()

    def punch_hole(self, message):
        """Send five 'punch' datagrams, 0.2 s apart, to ``message['client_addr']``."""
        try:
            target = tuple(message['client_addr'])
        except (KeyError, TypeError) as e:
            print(f"打洞失败: {str(e)}")
            return
        probe = json.dumps({'action': 'punch', 'client_id': self.client_id}).encode()
        for _ in range(5):
            try:
                self.udp_sock.sendto(probe, target)
            except Exception as e:
                print(f"打洞失败: {str(e)}")
            time.sleep(0.2)

    def shutdown(self):
        """Notify the coordinator, close all connections and the UDP socket."""
        self.running = False
        try:
            self.udp_sock.sendto(json.dumps({'action': 'stop_provider'}).encode(), self.coordinator_addr)
        except OSError as e:
            # Still release local resources even if the coordinator is unreachable.
            print(f"通知协调服务器失败: {str(e)}")
        for conn_id in list(self.active_connections):
            self.close_connection(conn_id)
        self.udp_sock.close()
        print("服务提供端已停止")
if __name__ == '__main__':
    import argparse

    # Defaults reproduce the original hard-coded configuration.
    COORDINATOR_ADDR = ('www.awin-x.top', 5000)
    SERVICE_NAME = "my_service"
    INTERNAL_PORT = 22

    parser = argparse.ArgumentParser(description="AwinSimpleP2P service provider")
    parser.add_argument('--coordinator-host', default=COORDINATOR_ADDR[0],
                        help="coordinator server host name or IP")
    parser.add_argument('--coordinator-port', type=int, default=COORDINATOR_ADDR[1],
                        help="coordinator server UDP port")
    parser.add_argument('--service-name', default=SERVICE_NAME,
                        help="name the service is registered under")
    parser.add_argument('--internal-port', type=int, default=INTERNAL_PORT,
                        help="TCP port of the local service on 127.0.0.1")
    args = parser.parse_args()

    provider = ServiceProvider(
        (args.coordinator_host, args.coordinator_port),
        args.service_name,
        args.internal_port,
    )
    provider.run()