-
Notifications
You must be signed in to change notification settings - Fork 1
/
TcpStream.cpp
131 lines (123 loc) · 3.75 KB
/
TcpStream.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "TcpStream.hpp"
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#ifndef _WIN32
#include <unistd.h>
#endif
// Signed result type of recv()/send(): Winsock returns int, POSIX ssize_t.
#ifdef _WIN32
using RECV_SEND_T = int;
#else
using RECV_SEND_T = ssize_t;
#endif
/// Connects to `host`:`port` over TCP with Nagle's algorithm disabled
/// (TCP_NODELAY).
///
/// Each IPv4 address returned by getaddrinfo() is tried in turn, with a fresh
/// socket per attempt, until one connects. A constructor that throws never
/// runs the destructor. So every failure path first releases the address
/// list, any socket opened here and (on Windows) the WSAStartup reference.
///
/// @param host Hostname or numeric IPv4 address to resolve.
/// @param port TCP port to connect to.
/// @throws std::runtime_error if Winsock initialization, name resolution,
///         socket creation or setting TCP_NODELAY fails, or if no resolved
///         address accepts the connection.
TcpStream::TcpStream(const std::string &host, int port) {
#ifdef _WIN32
  WSADATA wsa_data;
  if (WSAStartup(MAKEWORD(1, 1), &wsa_data) != 0) {
    throw std::runtime_error("Failed to initialize sockets");
  }
#endif
  // Balances this constructor's WSAStartup (Windows only), then throws.
  auto fail = [](const char *message) {
#ifdef _WIN32
    WSACleanup();
#endif
    throw std::runtime_error(message);
  };
  // Closes the socket currently held in `sock`. The name is global-qualified
  // so it cannot resolve to a member function.
  auto closeSock = [this]() {
#ifdef _WIN32
    ::closesocket(sock);
#else
    ::close(sock);
#endif
  };

  addrinfo hints;
  std::memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  addrinfo *servinfo = nullptr;
  if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints,
                  &servinfo) != 0) {
    fail("Failed to get addr info");
  }

  for (addrinfo *ai = servinfo; ai != nullptr; ai = ai->ai_next) {
    sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    if (sock == -1) {
      freeaddrinfo(servinfo);
      fail("Failed to create socket");
    }
    int yes = 1;
    if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY,
                   reinterpret_cast<const char *>(&yes), sizeof(yes)) < 0) {
      closeSock();
      freeaddrinfo(servinfo);
      fail("Failed to set TCP_NODELAY");
    }
    if (connect(sock, ai->ai_addr, ai->ai_addrlen) == 0) {
      freeaddrinfo(servinfo);
      return;
    }
    // POSIX leaves a socket's state unspecified after a failed connect(), so
    // discard it and use a fresh socket for the next address.
    closeSock();
  }
  freeaddrinfo(servinfo);
  fail("Failed to connect");
}
class TcpInputStream : public InputStream {
public:
TcpInputStream(std::shared_ptr<TcpStream> tcpStream)
: tcpStream(tcpStream), bufferPos(0), bufferSize(0) {}
void readBytes(char *buffer, size_t byteCount) {
while (byteCount > 0) {
if (bufferSize > 0) {
if (bufferSize >= byteCount) {
memcpy(buffer, this->buffer + bufferPos, byteCount);
bufferPos += byteCount;
bufferSize -= byteCount;
return;
}
memcpy(buffer, this->buffer + bufferPos, bufferSize);
buffer += bufferSize;
byteCount -= bufferSize;
bufferPos += bufferSize;
bufferSize = 0;
}
if (bufferPos == BUFFER_CAPACITY) {
bufferPos = 0;
}
RECV_SEND_T received =
recv(tcpStream->sock, this->buffer + bufferPos + bufferSize,
BUFFER_CAPACITY - bufferPos - bufferSize, 0);
if (received < 0) {
throw std::runtime_error("Failed to read from socket");
}
bufferSize += received;
}
}
private:
static const size_t BUFFER_CAPACITY = 8 * 1024;
char buffer[BUFFER_CAPACITY];
size_t bufferPos;
size_t bufferSize;
std::shared_ptr<TcpStream> tcpStream;
};
class TcpOutputStream : public OutputStream {
public:
TcpOutputStream(std::shared_ptr<TcpStream> tcpStream)
: tcpStream(tcpStream), bufferPos(0), bufferSize(0) {}
void writeBytes(const char *buffer, size_t byteCount) {
while (byteCount > 0) {
size_t capacity = BUFFER_CAPACITY - bufferPos - bufferSize;
if (capacity >= byteCount) {
memcpy(this->buffer + bufferPos + bufferSize, buffer, byteCount);
bufferSize += byteCount;
return;
}
memcpy(this->buffer + bufferPos + bufferSize, buffer, capacity);
bufferSize += capacity;
byteCount -= capacity;
buffer += capacity;
flush();
}
}
void flush() {
while (bufferSize > 0) {
RECV_SEND_T sent =
send(tcpStream->sock, buffer + bufferPos, bufferSize, 0);
if (sent < 0) {
throw std::runtime_error("Failed to write to socket");
}
bufferPos += sent;
bufferSize -= sent;
}
bufferPos = 0;
}
private:
static const size_t BUFFER_CAPACITY = 8 * 1024;
char buffer[BUFFER_CAPACITY];
size_t bufferPos;
size_t bufferSize;
std::shared_ptr<TcpStream> tcpStream;
};
std::shared_ptr<InputStream>
getInputStream(std::shared_ptr<TcpStream> tcpStream) {
return std::shared_ptr<TcpInputStream>(new TcpInputStream(tcpStream));
}
std::shared_ptr<OutputStream>
getOutputStream(std::shared_ptr<TcpStream> tcpStream) {
return std::shared_ptr<TcpOutputStream>(new TcpOutputStream(tcpStream));
}