summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/base/vnc/vncserver.cc145
-rw-r--r--src/base/vnc/vncserver.hh18
2 files changed, 89 insertions, 74 deletions
diff --git a/src/base/vnc/vncserver.cc b/src/base/vnc/vncserver.cc
index 216fa2fb4..9cf38dc2d 100644
--- a/src/base/vnc/vncserver.cc
+++ b/src/base/vnc/vncserver.cc
@@ -198,7 +198,10 @@ VncServer::accept()
panic("%s: cannot accept a connection if not listening!", name());
int fd = listener.accept(true);
- fatal_if(fd < 0, "%s: failed to accept VNC connection!", name());
+ if (fd < 0) {
+ warn("%s: failed to accept VNC connection!", name());
+ return;
+ }
if (dataFd != -1) {
char message[] = "vnc server already attached!\n";
@@ -210,7 +213,7 @@ VncServer::accept()
dataFd = fd;
// Send our version number to the client
- write((uint8_t*)vncVersion(), strlen(vncVersion()));
+ write((uint8_t *)vncVersion(), strlen(vncVersion()));
// read the client response
dataEvent = new DataEvent(this, dataFd, POLLIN);
@@ -224,7 +227,6 @@ void
VncServer::data()
{
// We have new data, see if we can handle it
- size_t len;
DPRINTF(VNC, "Vnc client message recieved\n");
switch (curState) {
@@ -237,8 +239,8 @@ VncServer::data()
case WaitForClientInit:
// Don't care about shared, just need to read it out of the socket
uint8_t shared;
- len = read(&shared);
- assert(len == 1);
+ if (!read(&shared))
+ return;
// Send our idea of the frame buffer
sendServerInit();
@@ -246,12 +248,8 @@ VncServer::data()
break;
case NormalPhase:
uint8_t message_type;
- len = read(&message_type);
- if (!len) {
- detach();
+ if (!read(&message_type))
return;
- }
- assert(len == 1);
switch (message_type) {
case ClientSetPixelFormat:
@@ -273,8 +271,9 @@ VncServer::data()
recvCutText();
break;
default:
- panic("Unimplemented message type recv from client: %d\n",
- message_type);
+ warn("Unimplemented message type recv from client: %d\n",
+ message_type);
+ detach();
break;
}
break;
@@ -285,7 +284,7 @@ VncServer::data()
// read from socket
-size_t
+bool
VncServer::read(uint8_t *buf, size_t len)
{
if (dataFd < 0)
@@ -297,59 +296,58 @@ VncServer::read(uint8_t *buf, size_t len)
} while (ret == -1 && errno == EINTR);
- if (ret <= 0){
- DPRINTF(VNC, "Read failed.\n");
+ if (ret != len) {
+ DPRINTF(VNC, "Read failed %d.\n", ret);
detach();
- return 0;
+ return false;
}
- return ret;
+ return true;
}
-size_t
+bool
VncServer::read1(uint8_t *buf, size_t len)
{
- size_t read_len M5_VAR_USED;
- read_len = read(buf + 1, len - 1);
- assert(read_len == len - 1);
- return read_len;
+ return read(buf + 1, len - 1);
}
template<typename T>
-size_t
+bool
VncServer::read(T* val)
{
- return read((uint8_t*)val, sizeof(T));
+ return read((uint8_t *)val, sizeof(T));
}
// write to socket
-size_t
+bool
VncServer::write(const uint8_t *buf, size_t len)
{
if (dataFd < 0)
panic("Vnc client not properly attached.\n");
- ssize_t ret;
- ret = atomic_write(dataFd, buf, len);
+ ssize_t ret = atomic_write(dataFd, buf, len);
- if (ret < len)
+ if (ret != len) {
+ DPRINTF(VNC, "Write failed.\n");
detach();
+ return false;
+ }
- return ret;
+ return true;
}
template<typename T>
-size_t
+bool
VncServer::write(T* val)
{
- return write((uint8_t*)val, sizeof(T));
+ return write((uint8_t *)val, sizeof(T));
}
-size_t
+bool
VncServer::write(const char* str)
{
- return write((uint8_t*)str, strlen(str));
+ return write((uint8_t *)str, strlen(str));
}
// detach a vnc client
@@ -377,7 +375,8 @@ void
VncServer::sendError(const char* error_msg)
{
uint32_t len = strlen(error_msg);
- write(&len);
+ if (!write(&len))
+ return;
write(error_msg);
}
@@ -392,8 +391,10 @@ VncServer::checkProtocolVersion()
// Null terminate the message so it's easier to work with
version_string[12] = 0;
- len = read((uint8_t*)version_string, 12);
- assert(len == 12);
+ if (!read((uint8_t *)version_string, sizeof(version_string) - 1)) {
+ warn("Failed to read protocol version.");
+ return;
+ }
uint32_t major, minor;
@@ -402,6 +403,7 @@ VncServer::checkProtocolVersion()
warn(" Malformed protocol version %s\n", version_string);
sendError("Malformed protocol version\n");
detach();
+ return;
}
DPRINTF(VNC, "Client request protocol version %d.%d\n", major, minor);
@@ -412,16 +414,18 @@ VncServer::checkProtocolVersion()
uint8_t err = AuthInvalid;
write(&err);
detach();
+ return;
}
// Auth is different based on version number
if (minor < 7) {
uint32_t sec_type = htobe((uint32_t)AuthNone);
- write(&sec_type);
+ if (!write(&sec_type))
+ return;
} else {
uint8_t sec_cnt = 1;
uint8_t sec_type = htobe((uint8_t)AuthNone);
- write(&sec_cnt);
- write(&sec_type);
+ if (!write(&sec_cnt) || !write(&sec_type))
+ return;
}
// Wait for client to respond
@@ -434,9 +438,8 @@ VncServer::checkSecurity()
assert(curState == WaitForSecurityResponse);
uint8_t security_type;
- size_t len M5_VAR_USED = read(&security_type);
-
- assert(len == 1);
+ if (!read(&security_type))
+ return;
if (security_type != AuthNone) {
warn("Unknown VNC security type\n");
@@ -446,7 +449,8 @@ VncServer::checkSecurity()
DPRINTF(VNC, "Sending security auth OK\n");
uint32_t success = htobe(VncOK);
- write(&success);
+ if (!write(&success))
+ return;
curState = WaitForClientInit;
}
@@ -475,7 +479,8 @@ VncServer::sendServerInit()
msg.namelen = htobe(msg.namelen);
memcpy(msg.name, "M5", 2);
- write(&msg);
+ if (!write(&msg))
+ return;
curState = NormalPhase;
}
@@ -485,7 +490,8 @@ VncServer::setPixelFormat()
DPRINTF(VNC, "Received pixel format from client message\n");
PixelFormatMessage pfm;
- read1((uint8_t*)&pfm, sizeof(PixelFormatMessage));
+ if (!read1((uint8_t *)&pfm, sizeof(PixelFormatMessage)))
+ return;
DPRINTF(VNC, " -- bpp = %d; depth = %d; be = %d\n", pfm.px.bpp,
pfm.px.depth, pfm.px.bigendian);
@@ -504,8 +510,10 @@ VncServer::setPixelFormat()
betoh(pfm.px.bluemax) != pixelFormat.bluemax ||
betoh(pfm.px.redshift) != pixelFormat.redshift ||
betoh(pfm.px.greenshift) != pixelFormat.greenshift ||
- betoh(pfm.px.blueshift) != pixelFormat.blueshift)
- fatal("VNC client doesn't support true color raw encoding\n");
+ betoh(pfm.px.blueshift) != pixelFormat.blueshift) {
+ warn("VNC client doesn't support true color raw encoding\n");
+ detach();
+ }
}
void
@@ -514,7 +522,8 @@ VncServer::setEncodings()
DPRINTF(VNC, "Received supported encodings from client\n");
PixelEncodingsMessage pem;
- read1((uint8_t*)&pem, sizeof(PixelEncodingsMessage));
+ if (!read1((uint8_t *)&pem, sizeof(PixelEncodingsMessage)))
+ return;
pem.num_encodings = betoh(pem.num_encodings);
@@ -523,9 +532,8 @@ VncServer::setEncodings()
for (int x = 0; x < pem.num_encodings; x++) {
int32_t encoding;
- size_t len M5_VAR_USED;
- len = read(&encoding);
- assert(len == sizeof(encoding));
+ if (!read(&encoding))
+ return;
DPRINTF(VNC, " -- supports %d\n", betoh(encoding));
switch (betoh(encoding)) {
@@ -538,8 +546,10 @@ VncServer::setEncodings()
}
}
- if (!supportsRawEnc)
- fatal("VNC clients must always support raw encoding\n");
+ if (!supportsRawEnc) {
+ warn("VNC clients must always support raw encoding\n");
+ detach();
+ }
}
void
@@ -548,7 +558,8 @@ VncServer::requestFbUpdate()
DPRINTF(VNC, "Received frame buffer update request from client\n");
FrameBufferUpdateReq fbr;
- read1((uint8_t*)&fbr, sizeof(FrameBufferUpdateReq));
+ if (!read1((uint8_t *)&fbr, sizeof(FrameBufferUpdateReq)))
+ return;
fbr.x = betoh(fbr.x);
fbr.y = betoh(fbr.y);
@@ -566,7 +577,8 @@ VncServer::recvKeyboardInput()
{
DPRINTF(VNC, "Received keyboard input from client\n");
KeyEventMessage kem;
- read1((uint8_t*)&kem, sizeof(KeyEventMessage));
+ if (!read1((uint8_t *)&kem, sizeof(KeyEventMessage)))
+ return;
kem.key = betoh(kem.key);
DPRINTF(VNC, " -- received key code %d (%s)\n", kem.key, kem.down_flag ?
@@ -582,7 +594,8 @@ VncServer::recvPointerInput()
DPRINTF(VNC, "Received pointer input from client\n");
PointerEventMessage pem;
- read1((uint8_t*)&pem, sizeof(PointerEventMessage));;
+ if (!read1((uint8_t *)&pem, sizeof(PointerEventMessage)))
+ return;
pem.x = betoh(pem.x);
pem.y = betoh(pem.y);
@@ -599,18 +612,18 @@ VncServer::recvCutText()
DPRINTF(VNC, "Received client copy buffer message\n");
ClientCutTextMessage cct;
- read1((uint8_t*)&cct, sizeof(ClientCutTextMessage));
+ if (!read1((uint8_t *)&cct, sizeof(ClientCutTextMessage)))
+ return;
char str[1025];
size_t data_len = betoh(cct.length);
DPRINTF(VNC, "String length %d\n", data_len);
while (data_len > 0) {
- size_t len;
size_t bytes_to_read = data_len > 1024 ? 1024 : data_len;
- len = read((uint8_t*)&str, bytes_to_read);
+ if (!read((uint8_t *)&str, bytes_to_read))
+ return;
str[bytes_to_read] = 0;
- assert(len >= data_len);
- data_len -= len;
+ data_len -= bytes_to_read;
DPRINTF(VNC, "Buffer: %s\n", str);
}
@@ -651,8 +664,8 @@ VncServer::sendFrameBufferUpdate()
fbr.encoding = htobe(fbr.encoding);
// send headers to client
- write(&fbu);
- write(&fbr);
+ if (!write(&fbu) || !write(&fbr))
+ return;
assert(fb);
@@ -665,7 +678,8 @@ VncServer::sendFrameBufferUpdate()
raw_pixel += pixelConverter.length;
}
- write(line_buffer.data(), line_buffer.size());
+ if (!write(line_buffer.data(), line_buffer.size()))
+ return;
}
}
@@ -695,7 +709,8 @@ VncServer::sendFrameBufferResized()
fbr.encoding = htobe(fbr.encoding);
// send headers to client
- write(&fbu);
+ if (!write(&fbu))
+ return;
write(&fbr);
// No actual data is sent in this message
diff --git a/src/base/vnc/vncserver.hh b/src/base/vnc/vncserver.hh
index a52850323..99f4b5fe1 100644
--- a/src/base/vnc/vncserver.hh
+++ b/src/base/vnc/vncserver.hh
@@ -216,9 +216,9 @@ class VncServer : public VncInput
/** Read some data from the client
* @param buf the data to read
* @param len the amount of data to read
- * @return length read
+ * @return whether the read was successful
*/
- size_t read(uint8_t *buf, size_t len);
+ bool read(uint8_t *buf, size_t len);
/** Read len -1 bytes from the client into the buffer provided + 1
* assert that we read enough bytes. This function exists to handle
@@ -226,35 +226,35 @@ class VncServer : public VncInput
* the first byte which describes which one we're reading
* @param buf the address of the buffer to add one to and read data into
* @param len the amount of data + 1 to read
- * @return length read
+ * @return whether the read was successful.
*/
- size_t read1(uint8_t *buf, size_t len);
+ bool read1(uint8_t *buf, size_t len);
/** Templated version of the read function above to
* read simple data to the client
* @param val data to recv from the client
*/
- template <typename T> size_t read(T* val);
+ template <typename T> bool read(T* val);
/** Write a buffer to the client.
* @param buf buffer to send
* @param len length of the buffer
- * @return number of bytes sent
+ * @return whether the write was successful
*/
- size_t write(const uint8_t *buf, size_t len);
+ bool write(const uint8_t *buf, size_t len);
/** Templated version of the write function above to
* write simple data to the client
* @param val data to send to the client
*/
- template <typename T> size_t write(T* val);
+ template <typename T> bool write(T* val);
/** Send a string to the client
* @param str string to transmit
*/
- size_t write(const char* str);
+ bool write(const char* str);
/** Check the client's protocol verion for compatibility and send
* the security types we support