#include "stdafx.h"

#include "UdpHolepunching.h"



using namespace RLib;
using namespace boost;
using namespace boost::asio;



/////////////////////////////////////////////////////////////////////////

// Shared state behind a CUdpSocket. Pending asynchronous receives hold a
// boost::shared_ptr to this object, so it can outlive the CUdpSocket that created it.
class CUdpSocket::CSocket
{
public:
	CUdpSocket		*m_pUdpSocket;	// owning CUdpSocket; ~CUdpSocket sets this to NULL
	ip::udp::socket	m_socket;		// IPv4 UDP socket, opened on construction
public:
	// Creates the socket on ioService and opens it for IPv4.
	// Throws boost::system::system_error if the socket cannot be opened.
	CSocket(CUdpSocket &udpSocket,io_service &ioService)
		:m_pUdpSocket(&udpSocket)
		,m_socket(ioService,ip::udp::v4())	// this constructor also opens the socket
		{
		}
};

// Creates the shared socket state: an IPv4 UDP socket opened on ioService,
// with a back-pointer to this object. Asynchronous receive handlers keep their
// own shared reference to that state.
CUdpSocket::CUdpSocket(asio::io_service &ioService)
:m_spSocket(new CSocket(*this,ioService))
{
}

// Detaches this object from its shared CSocket and closes the socket.
// Pending receive handlers hold their own shared_ptr to the CSocket. They return
// without calling back once m_pUdpSocket is NULL. Closing here also cancels any
// outstanding receive, so the OS socket and its bound port are released now.
// Without the close, they would stay open until the pending receive completes
// or the io_service is destroyed. The non-throwing overload is used because this
// is a destructor.
CUdpSocket::~CUdpSocket()
{
	m_spSocket->m_pUdpSocket = NULL;		// mark as destroyed

	boost::system::error_code ecIgnored;	// nothing useful to do with a close failure here
	m_spSocket->m_socket.close(ecIgnored);
}

void CUdpSocket::Close()
{
	m_spSocket->m_socket.close();
}

// Binds the socket to nPort on all local IPv4 addresses.
// Returns true on success; on failure returns false and ec holds the reason.
bool CUdpSocket::Bind(unsigned short nPort,boost::system::error_code &ec)
{
	// Test ec after the call instead of comparing the call's return value with
	// `false`. That comparison does not compile once error_code has an explicit
	// operator bool, and newer Boost.Asio declares this overload as returning void.
	m_spSocket->m_socket.bind( ip::udp::endpoint(ip::udp::v4(),nPort), ec );
	return !ec;	// true when there was no error
}

bool CUdpSocket::Receive(const FuncOnReceived &funcOnReceived,unsigned short nBufferSize)
{
	struct F{
		static void OnReceived(const boost::system::error_code &ec,size_t bytesReceived,boost::shared_ptr<CSocket> spSocket,boost::shared_ptr<ip::udp::endpoint> spEndpointRemote,boost::shared_ptr<vector<char>> spBuffer,const FuncOnReceived funcOnReceived)
		{
			if( !spSocket->m_pUdpSocket ) return;						// jĂ?
			if( ec ){	// Error
				if( ec == boost::asio::error::operation_aborted ) return;	// Xbh̏I܂̓AvP[V̗vɂāAI/O ͒~܂B
				//ATLTRACE( _T("\nCUdpSocket Error OnRecived %s -> %s"), CUdpSocket::GetTextAddress(*spEndpointRemote), CString(ec.message().c_str()) );
				//return; G[łR[͂
			}

			spBuffer->resize(bytesReceived);
			if( funcOnReceived ){
				funcOnReceived( *spSocket->m_pUdpSocket, ec, *spEndpointRemote, spBuffer );
			}
		}
	};

	boost::shared_ptr<ip::udp::endpoint> spEndpoint(new ip::udp::endpoint);
	boost::shared_ptr<vector<char>> spBuffer(new vector<char>(nBufferSize));	// Mobt@
	m_spSocket->m_socket.async_receive_from(
		asio::buffer(*spBuffer),
		*spEndpoint,
		boost::bind(
			&F::OnReceived,
			boost::asio::placeholders::error,
			boost::asio::placeholders::bytes_transferred,
			m_spSocket,spEndpoint,spBuffer,funcOnReceived)
	);

	return true;
}

bool CUdpSocket::Connect(const std::string &sDomain,const std::string &sPort)
{
	ip::udp::resolver resolver(m_spSocket->m_socket.get_io_service());
	ip::udp::resolver::query query(ip::udp::v4(), sDomain, sPort );
	ip::udp::resolver::iterator iterator = resolver.resolve(query);
	boost::system::error_code ec;
	do{
		if( m_spSocket->m_socket.connect(*iterator,ec) ){					// rŏ܂Ȃ悤ɓ֐g
			if( ++iterator == asio::ip::udp::resolver::iterator() ) break;	// ȂΔ
		}
	}while(ec);
	if( ec ){								// ڑsȂ
		ATLTRACE( _T("\nCUdpSocket Error Connect -> %s"), CString(ec.message().c_str()) );
		return false;
	}
	return true;
}

// Synchronously sends *spData as one datagram to endPoint.
// Returns true only if no error occurred and the whole buffer was sent.
// On an error, returns false, writes a trace, and ec holds the reason.
// A null spData asserts in debug builds and fails with ec = invalid_argument.
bool CUdpSocket::SendTo(const boost::asio::ip::udp::endpoint &endPoint,const boost::shared_ptr<const vector<char>> &spData,boost::system::error_code &ec)
{
	if( !spData ){
		BOOST_ASSERT(false);
		ec = boost::asio::error::invalid_argument;	// don't leave a stale "success" in ec
		return false;
	}
	const size_t size = m_spSocket->m_socket.send_to( asio::buffer(*spData), endPoint, 0, ec );
	if( !ec ) return size == spData->size();
	ATLTRACE( _T("\nCUdpSocket Error Send -> %s"), CString(ec.message().c_str()) );
	return false;
}

// Synchronously sends *spData as one datagram to the connected peer (see Connect()).
// Returns true only if no error occurred and the whole buffer was sent.
// On an error, returns false, writes a trace, and ec holds the reason.
// A null spData asserts in debug builds and fails with ec = invalid_argument.
bool CUdpSocket::Send(const boost::shared_ptr<const vector<char>> &spData,boost::system::error_code &ec)
{
	if( !spData ){
		BOOST_ASSERT(false);
		ec = boost::asio::error::invalid_argument;	// don't leave a stale "success" in ec
		return false;
	}
	const size_t size = m_spSocket->m_socket.send( asio::buffer(*spData), 0, ec );
	if( !ec ) return size == spData->size();
	ATLTRACE( _T("\nCUdpSocket Error Send -> %s"), CString(ec.message().c_str()) );
	return false;
}

////////////////////////////////////////////////


// Implementation behind CUdpHolepunching.
//
// Owns an io_service and one UDP socket. A message-only window calls
// m_ioService.poll() on every WM_TIMER, which runs the asio handlers below.
// Every 5000 ms the node sends its id and member list to every known member.
// Received lists are merged into m_mapMember.
//
// Packet layout (raw object bytes, no byte-order conversion):
//   [sender CRUid] followed by zero or more [member CRUid][CEndpoint] pairs.
class CUdpHolepunching::CMain
	:public CWindowImpl<CMain>
{
public:
	DECLARE_WND_CLASS( _T("CMain") );
	BEGIN_MSG_MAP(CMain)
		MESSAGE_HANDLER(WM_TIMER, OnTimer)
	END_MSG_MAP()
public:
	// Runs whatever asio handlers are ready, without blocking the message loop.
	LRESULT OnTimer(UINT, WPARAM, LPARAM, BOOL&)
		{
			m_ioService.poll();
			return 0;
		}
public:
	const CRUid					m_id;			// this node's id (from CRUid::RandomGenerator())
	FuncMessage					m_funcMessage;	// receives human-readable log lines
	io_service					m_ioService;	// declared before the members below that use it
	CUdpSocket					m_udpSocket;
	auto_ptr<deadline_timer>	m_apTimer;		// periodic send timer; reset in ~CMain

	// One member endpoint as written to the wire. Packed so that sizeof(CEndpoint)
	// equals the number of bytes serialized. Both fields are in the sender's host
	// byte order, so peers are assumed to share endianness.
#pragma pack(push,1)
	struct CEndpoint
	{
		unsigned long	m_nIpv4;	// address_v4::to_ulong()
		unsigned short	m_nPort;
		CEndpoint()
			{}
		CEndpoint(const ip::udp::endpoint &endPoint)
			{
				m_nIpv4 = endPoint.address().to_v4().to_ulong();
				m_nPort = endPoint.port();
			}
	};
#pragma pack(pop)

	std::map<CRUid,CEndpoint>	m_mapMember;	// known members, keyed by id

	// Completion callback for m_udpSocket.Receive().
	// It logs the datagram (or the error), merges the member list the datagram
	// carries into m_mapMember, and re-arms the receive so the socket keeps listening.
	// The payload comes from the network and is treated as untrusted: packets that
	// are too short are logged and ignored, and a trailing partial entry is dropped.
	void OnReceived(CUdpSocket &udpSocket,const boost::system::error_code &ec,const boost::asio::ip::udp::endpoint &endPointRemote,const boost::shared_ptr<const vector<char>> &spBuffer)
		{
			if( ec ){	// Error
				m_funcMessage( CRString::Format(_T("%s : OnRecived Error [%s] %s"), CString(m_id.GetText().c_str()), CUdpSocket::GetTextAddress(endPointRemote), CString(ec.message().c_str())) );
			}else{
				m_funcMessage( CRString::Format(_T("%s : OnRecived [%s]"), CString(m_id.GetText().c_str()), CUdpSocket::GetTextAddress(endPointRemote) ) );

				// Update the member list. Parse in place with a read iterator instead of
				// copying the buffer and erasing from its front for every field.
				const vector<char> &vBuffer = *spBuffer;
				if( vBuffer.size() < sizeof(CRUid) ){
					// Not even a sender id: ignore it (network input must not trip an assertion).
					m_funcMessage( CRString::Format(_T("%s : OnReceived Invalid packet [%s]"), CString(m_id.GetText().c_str()), CUdpSocket::GetTextAddress(endPointRemote) ) );
				}else{
					vector<char>::const_iterator itRead = vBuffer.begin();

					// The sender: record it at the address this datagram actually came from
					// (presumably its NAT-mapped public address when behind NAT).
					CRUid idSender;
					std::copy( itRead, itRead+sizeof(idSender), stdext::checked_array_iterator<char*>(reinterpret_cast<char*>(&idSender),sizeof(idSender)) );
					itRead += sizeof(idSender);
					m_mapMember[idSender] = CEndpoint(endPointRemote);

					// The sender's member list, merged into ours.
					while( static_cast<size_t>(vBuffer.end()-itRead) >= sizeof(CRUid)+sizeof(CEndpoint) ){
						CRUid id;
						std::copy( itRead, itRead+sizeof(id), stdext::checked_array_iterator<char*>(reinterpret_cast<char*>(&id),sizeof(id)) );
						itRead += sizeof(id);
						CEndpoint ep;
						std::copy( itRead, itRead+sizeof(ep), stdext::checked_array_iterator<char*>(reinterpret_cast<char*>(&ep),sizeof(ep)) );
						itRead += sizeof(ep);
						if( id != m_id ){			// never add ourselves from someone else's list
							m_mapMember[id] = ep;
						}
					}
				}
			}

			udpSocket.Receive( boost::bind(&CMain::OnReceived,this,_1,_2,_3,_4) );		// wait for the next datagram
		}

public:
	// Creates the message-only polling window (WM_TIMER with a requested interval
	// of 1 ms) and starts the periodic send timer, whose first tick is due immediately.
	CMain(const FuncMessage &funcMessage)
		:m_id(CRUid::RandomGenerator())
		,m_funcMessage(funcMessage)
		,m_udpSocket(m_ioService)
		,m_apTimer(new deadline_timer(m_ioService))
		{
			Create(HWND_MESSAGE,CRect(0,0,0,0),_T(""),NULL);
			::SetTimer( m_hWnd, 0, 1, NULL );

			m_apTimer->expires_from_now( boost::posix_time::millisec(0) );			// first tick right away
			m_apTimer->async_wait( boost::bind( &CMain::OnDeadlineTimer, this, _1 ) );
		}

	// Stops the send timer, closes the socket, then destroys the polling window.
	// Handlers bound to `this` are not run afterwards: the window that polls
	// m_ioService goes away here, and m_ioService is destroyed with this object.
	~CMain()
		{
			m_apTimer.reset();
			m_udpSocket.Close();

			if( m_hWnd ) DestroyWindow();
		}

	// Periodic send. Re-arms itself every 5000 ms, then sends our id plus the whole
	// member list to every known member, and logs where it was sent.
	void OnDeadlineTimer(const boost::system::error_code& error)
		{
			if( error ) return;	// includes operation_aborted (timer cancelled or destroyed)

			if( m_apTimer.get() ){						// schedule the next tick unless shutting down
				m_apTimer->expires_at( m_apTimer->expires_at() + boost::posix_time::millisec(5000) );
				m_apTimer->async_wait( boost::bind( &CMain::OnDeadlineTimer, this, _1 ) );
			}

			// Build the packet once: our id, then an (id, endpoint) pair per known member.
			boost::shared_ptr<vector<char>> spData(new vector<char>);
			spData->reserve( sizeof(m_id) + m_mapMember.size()*(sizeof(CRUid)+sizeof(CEndpoint)) );
			std::copy( reinterpret_cast<const char*>(&m_id), reinterpret_cast<const char*>(&m_id)+sizeof(m_id), std::back_inserter(*spData) );
			for(std::map<CRUid,CEndpoint>::const_iterator i=m_mapMember.begin(); i!=m_mapMember.end(); ++i ){
				const CRUid &id = i->first;
				std::copy( reinterpret_cast<const char*>(&id), reinterpret_cast<const char*>(&id)+sizeof(id), std::back_inserter(*spData) );
				const CEndpoint &ep = i->second;
				std::copy( reinterpret_cast<const char*>(&ep), reinterpret_cast<const char*>(&ep)+sizeof(ep), std::back_inserter(*spData) );
			}

			// Send it to every member, collecting " [address]" entries for one log line.
			CString sSendTo;
			for(std::map<CRUid,CEndpoint>::const_iterator i=m_mapMember.begin(); i!=m_mapMember.end(); ++i ){
				const CEndpoint &ep = i->second;
				ip::address addr(boost::asio::ip::address_v4(ep.m_nIpv4));
				ip::udp::endpoint endpoint( addr, ep.m_nPort );
				m_udpSocket.SendTo( endpoint, spData );
				sSendTo.AppendFormat(_T(" [%s]"),CUdpSocket::GetTextAddress(endpoint));
			}
			if( !sSendTo.IsEmpty() ) m_funcMessage( CRString::Format(_T("%s : SendTo %s"), CString(m_id.GetText().c_str()), sSendTo ) );
		}

	// Binds the socket to nPort and starts listening. If the bind fails, logs and returns.
	void Server(unsigned short nPort)
		{
			boost::system::error_code ec;
			if( !m_udpSocket.Bind(nPort,ec) ){
				m_funcMessage( CRString::Format(_T("%s : Bind Error [%s]"), CString(m_id.GetText().c_str()), CString(ec.message().c_str()) ) );
				return;
			}
			m_udpSocket.Receive( boost::bind(&CMain::OnReceived,this,_1,_2,_3,_4) );		// start listening
		}

	// Resolves sDomain:sPort (IPv4) and sends our id there, so the receiver learns
	// the address we send from. Then starts listening. If resolution fails, logs and returns.
	void Client(const std::string &sDomain,const std::string &sPort)
		{
			boost::shared_ptr<vector<char>> spData(new vector<char>);
			std::copy( reinterpret_cast<const char*>(&m_id), reinterpret_cast<const char*>(&m_id)+sizeof(m_id), std::back_inserter(*spData) );

			ip::udp::resolver resolver(m_ioService);
			ip::udp::resolver::query query(ip::udp::v4(), sDomain, sPort );
			boost::system::error_code ec;
			ip::udp::resolver::iterator iterator = resolver.resolve(query,ec);
			if( !ec && iterator == ip::udp::resolver::iterator() ){
				ec = boost::asio::error::host_not_found;		// never dereference an empty result
			}
			if( ec ){
				m_funcMessage( CRString::Format(_T("%s : resolve Error [%s]"), CString(m_id.GetText().c_str()), CString(ec.message().c_str()) ) );
				return;
			}
			m_udpSocket.SendTo( *iterator,spData );
			m_udpSocket.Receive( boost::bind(&CMain::OnReceived,this,_1,_2,_3,_4) );		// start listening
		}

};


// Creates the implementation object. m_main refers to a heap-allocated CMain,
// which ~CUdpHolepunching deletes.
CUdpHolepunching::CUdpHolepunching(const FuncMessage &funcMessage)
:m_main(*new CMain(funcMessage))
{
}

// Deletes the CMain allocated by the constructor.
CUdpHolepunching::~CUdpHolepunching()
{
	CMain *const pMain = &m_main;
	delete pMain;
}

void CUdpHolepunching::Server(unsigned short nPort)
{
	m_main.m_ioService.post( boost::bind( &CMain::Server, &m_main, nPort ) );
}

void CUdpHolepunching::Client(const std::string &sDomain,const std::string &sPort)
{
	m_main.m_ioService.post( boost::bind( &CMain::Client, &m_main, sDomain, sPort ) );
}

