-
Notifications
You must be signed in to change notification settings - Fork 173
/
handshake.go
221 lines (189 loc) · 6.17 KB
/
handshake.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
// Copyright 2016 VMware, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// Client:
// generate a random uint8 (#)
// send 2 bytes Syn|#
//
// Server:
// generate a random uint8 (&)
// read at least 2 bytes and make sure Syn|# is received
// send 3 bytes Ack|#+1|& (or Nak)
//
// Client:
// read at least 3 bytes and make sure Ack|#+1|& is received
// send 2 bytes Ack|&+1
//
// Server:
// read at least 2 bytes and make sure Ack|&+1 is received
// send 1 byte Ack (or Nak)
// Client:
// read at least 1 byte and make sure Ack is received
package serial
import (
"fmt"
"io"
"math"
"math/rand"
"time"
log "github.com/Sirupsen/logrus"
"github.com/vmware/vic/pkg/trace"
)
// Single-byte control flags exchanged during the handshake. The values are the
// same as the ASCII SYN (0x16), ACK (0x06) and NAK (0x15) control characters.
const (
// flagSyn is the first byte of the client's opening packet. It is the only
// constant in this group declared with an explicit byte type.
flagSyn byte = 0x16
// flagAck marks an acknowledgement. It carries no explicit type, so it is an
// untyped constant.
flagAck = 0x06
// flagNak marks a rejection. It is untyped, like flagAck.
flagNak = 0x15
)
// HandshakeError should only occur if the protocol between HandshakeServer
// and HandshakeClient was violated.
type HandshakeError struct {
	// msg is the text reported by Error.
	msg string
}

// Error implements the error interface and returns the stored message.
func (e *HandshakeError) Error() string {
	return e.msg
}
// init seeds the package-level math/rand source with the current time in
// nanoseconds so the random positions used by the handshake differ between runs.
func init() {
	now := time.Now().UTC()
	rand.Seed(now.UnixNano())
}
// ReadAtLeastN reads at least l bytes from conn into buffer and returns a view
// of buffer that starts at the last l bytes read.
//
// The initial connection carries lots of garbage data; the handshake is meant
// to clear it and leave the connection in a known state so that the real ssh
// handshake can start. Client and server loop at different frequencies, so the
// client may have sent several requests before the server gets a chance to
// read. Keeping only the last l bytes discards the garbage and collapses
// duplicated flags down to the most recent one.
//
// The returned slice is not truncated to l bytes: its first l bytes are the
// payload and it extends to the end of buffer. On a read error the returned
// slice is nil and the error from io.ReadAtLeast is passed through unchanged.
func ReadAtLeastN(conn io.ReadWriter, buffer []byte, l int) ([]byte, error) {
	got, err := io.ReadAtLeast(conn, buffer, l)
	if err != nil {
		return nil, err
	}

	// Anything read beyond l bytes means the pipe was not empty; skip the
	// surplus prefix. When exactly l bytes were read the offset is zero and
	// the buffer is returned as is.
	surplus := got - l
	return buffer[surplus:], nil
}
// HandshakeClient establishes connection with the server making sure
// they both are in sync.
//
// Sequence, as seen from the client:
//  1. write Syn|pos, where pos is a random byte
//  2. read at least 3 bytes and interpret the last 3 as flag|posack|token;
//     flag must be Ack and posack must equal pos+1
//  3. write Ack|token+1
//  4. read exactly 1 byte, which must be Ack
//
// It returns nil once the final Ack is received. I/O errors from conn are
// returned unchanged; any protocol violation is reported as a *HandshakeError.
func HandshakeClient(conn io.ReadWriter) error {
	if tracing {
		defer trace.End(trace.Begin(""))
	}

	// generate a random pos between [0, math.MaxUint8)
	pos := uint8(rand.Intn(math.MaxUint8))

	buffer := make([]byte, 32*1024)

	// send syn with pos
	log.Debugf("HandshakeClient: Sending syn with pos %d", pos)
	if _, err := conn.Write([]byte{flagSyn, pos}); err != nil {
		log.Errorf("syn: write failed")
		return err
	}

	// read ack with pos+1 and token
	buffer, err := ReadAtLeastN(conn, buffer, 3)
	if err != nil {
		return err
	}

	// extract pos and the token from it
	flag, posack, token := buffer[0], buffer[1], buffer[2]

	if flag == flagNak {
		return &HandshakeError{
			msg: "HandshakeClient: Server declined handshake request",
		}
	}
	if flag != flagAck {
		return &HandshakeError{
			msg: fmt.Sprintf("HandshakeClient: Unexpected server response: %#v", flag),
		}
	}
	// pos+1 wraps around on uint8, matching the server's own uint8 arithmetic
	if posack != pos+1 {
		return &HandshakeError{
			msg: fmt.Sprintf("HandshakeClient: Unexpected ack position: %d, expected %d", posack, pos+1),
		}
	}

	log.Debugf("HandshakeClient: Sending ack with %d", token+1)
	if _, err := conn.Write([]byte{flagAck, token + 1}); err != nil {
		return err
	}

	// last ack packet is 1 byte and could be followed by SSH handshake, so read
	// only 1 byte and leave the rest in the net.Conn buffer. io.ReadFull keeps
	// reading until that single byte has actually arrived, so a Read that
	// returns (0, nil) cannot leave a stale value in the slice.
	last := buffer[:1]
	if _, err := io.ReadFull(conn, last); err != nil {
		return err
	}
	if last[0] != flagAck {
		// report the byte that was just read, not the flag from the earlier response
		return &HandshakeError{
			msg: fmt.Sprintf("HandshakeClient: Unexpected server response: %#v", last[0]),
		}
	}

	log.Debug("HandshakeClient: Connection established.")
	return nil
}
// HandshakeServer establishes connection with the client making sure
// they both are in sync.
//
// Sequence, as seen from the server:
//  1. read at least 2 bytes and interpret the last 2 as flag|token; flag must
//     be Syn
//  2. write Ack|token+1|pos, where pos is a random byte
//  3. read at least 2 bytes and interpret the last 2 as flag|token; flag must
//     be Ack and token must equal pos+1
//  4. write Ack
//
// On a protocol violation a single Nak byte is written and a *HandshakeError
// is returned; if writing the Nak itself fails, that write error is returned
// instead. Other I/O errors from conn are returned unchanged.
func HandshakeServer(conn io.ReadWriter) error {
	if tracing {
		defer trace.End(trace.Begin(""))
	}

	// generate a random pos between [0, math.MaxUint8)
	pos := uint8(rand.Intn(math.MaxUint8))

	buffer := make([]byte, 32*1024)

	log.Debug("HandshakeServer: Waiting for incoming syn request...")

	// Sync packet is 2 bytes, however if we read more than 2 it means buffer is not empty and data is not trusted for this sync.
	buffer, err := ReadAtLeastN(conn, buffer, 2)
	if err != nil {
		return err
	}

	// Read 2 bytes, extract flag and the token from it
	flag, token := buffer[0], buffer[1]
	if flag != flagSyn {
		if _, err := conn.Write([]byte{flagNak}); err != nil {
			return err
		}
		return &HandshakeError{
			msg: fmt.Sprintf("Unexpected syn packet: %x", flag),
		}
	}

	log.Debugf("HandshakeServer: Received syn with pos %d. Writing syn-ack with %d and %d", token, token+1, pos)
	// token contains position token that needs to be incremented by one to send it back.
	if _, err := conn.Write([]byte{flagAck, token + 1, pos}); err != nil {
		return err
	}

	// ACK packet is 2 bytes, however if we read more than 2 it means buffer is not empty and data is not trusted for this sync.
	buffer, err = ReadAtLeastN(conn, buffer, 2)
	if err != nil {
		return err
	}

	// Read 2 bytes, extract flag and the token from it
	flag, token = buffer[0], buffer[1]
	if flag != flagAck {
		if _, err := conn.Write([]byte{flagNak}); err != nil {
			return err
		}
		// this stage expects an ack, so name it as such in the error
		return &HandshakeError{
			msg: fmt.Sprintf("Unexpected ack packet: %x", flag),
		}
	}

	// token should contain incremented pos (uint8 arithmetic, so it wraps)
	if token != pos+1 {
		if _, err := conn.Write([]byte{flagNak}); err != nil {
			return err
		}
		return &HandshakeError{
			msg: fmt.Sprintf("HandshakeServer: Unexpected position %x, expected: %x", token, pos+1),
		}
	}

	log.Debugf("HandshakeServer: Received ACK with %d.", token)

	// send the last ACK
	if _, err := conn.Write([]byte{flagAck}); err != nil {
		return err
	}

	log.Debug("HandshakeServer: Connection established.")
	return nil
}