• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

topfreegames / pitaya / 5071265607

26 May 2023 11:31PM UTC coverage: 62.204% (+0.07%) from 62.131%
5071265607

push

github

web-flow
Add handshake validators

Add handshake validators

94 of 94 new or added lines in 4 files covered. (100.0%)

4776 of 7678 relevant lines covered (62.2%)

0.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.5
/service/handler.go
1
// Copyright (c) nano Author and TFG Co. All Rights Reserved.
2
//
3
// Permission is hereby granted, free of charge, to any person obtaining a copy
4
// of this software and associated documentation files (the "Software"), to deal
5
// in the Software without restriction, including without limitation the rights
6
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
// copies of the Software, and to permit persons to whom the Software is
8
// furnished to do so, subject to the following conditions:
9
//
10
// The above copyright notice and this permission notice shall be included in all
11
// copies or substantial portions of the Software.
12
//
13
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
// SOFTWARE.
20

21
package service
22

23
import (
24
        "context"
25
        "encoding/json"
26
        "fmt"
27
        "strings"
28
        "time"
29

30
        "github.com/nats-io/nuid"
31

32
        "github.com/topfreegames/pitaya/v2/acceptor"
33
        "github.com/topfreegames/pitaya/v2/pipeline"
34

35
        opentracing "github.com/opentracing/opentracing-go"
36
        "github.com/topfreegames/pitaya/v2/agent"
37
        "github.com/topfreegames/pitaya/v2/cluster"
38
        "github.com/topfreegames/pitaya/v2/component"
39
        "github.com/topfreegames/pitaya/v2/conn/codec"
40
        "github.com/topfreegames/pitaya/v2/conn/message"
41
        "github.com/topfreegames/pitaya/v2/conn/packet"
42
        "github.com/topfreegames/pitaya/v2/constants"
43
        pcontext "github.com/topfreegames/pitaya/v2/context"
44
        "github.com/topfreegames/pitaya/v2/docgenerator"
45
        e "github.com/topfreegames/pitaya/v2/errors"
46
        "github.com/topfreegames/pitaya/v2/logger"
47
        "github.com/topfreegames/pitaya/v2/metrics"
48
        "github.com/topfreegames/pitaya/v2/route"
49
        "github.com/topfreegames/pitaya/v2/serialize"
50
        "github.com/topfreegames/pitaya/v2/session"
51
        "github.com/topfreegames/pitaya/v2/timer"
52
        "github.com/topfreegames/pitaya/v2/tracing"
53
)
54

55
// handlerType is the type label this file passes to
// metrics.ReportTimingFromCtx when reporting handler timings.
var handlerType = "handler"
58

59
type (
60
        // HandlerService service
61
        HandlerService struct {
62
                baseService
63
                chLocalProcess   chan unhandledMessage // channel of messages that will be processed locally
64
                chRemoteProcess  chan unhandledMessage // channel of messages that will be processed remotely
65
                decoder          codec.PacketDecoder   // binary decoder
66
                remoteService    *RemoteService
67
                serializer       serialize.Serializer          // message serializer
68
                server           *cluster.Server               // server obj
69
                services         map[string]*component.Service // all registered service
70
                metricsReporters []metrics.Reporter
71
                agentFactory     agent.AgentFactory
72
                handlerPool      *HandlerPool
73
                handlers         map[string]*component.Handler // all handler method
74
        }
75

76
        unhandledMessage struct {
77
                ctx   context.Context
78
                agent agent.Agent
79
                route *route.Route
80
                msg   *message.Message
81
        }
82
)
83

84
// NewHandlerService creates and returns a new handler service
85
func NewHandlerService(
86
        packetDecoder codec.PacketDecoder,
87
        serializer serialize.Serializer,
88
        localProcessBufferSize int,
89
        remoteProcessBufferSize int,
90
        server *cluster.Server,
91
        remoteService *RemoteService,
92
        agentFactory agent.AgentFactory,
93
        metricsReporters []metrics.Reporter,
94
        handlerHooks *pipeline.HandlerHooks,
95
        handlerPool *HandlerPool,
96
) *HandlerService {
1✔
97
        h := &HandlerService{
1✔
98
                services:         make(map[string]*component.Service),
1✔
99
                chLocalProcess:   make(chan unhandledMessage, localProcessBufferSize),
1✔
100
                chRemoteProcess:  make(chan unhandledMessage, remoteProcessBufferSize),
1✔
101
                decoder:          packetDecoder,
1✔
102
                serializer:       serializer,
1✔
103
                server:           server,
1✔
104
                remoteService:    remoteService,
1✔
105
                agentFactory:     agentFactory,
1✔
106
                metricsReporters: metricsReporters,
1✔
107
                handlerPool:      handlerPool,
1✔
108
                handlers:         make(map[string]*component.Handler),
1✔
109
        }
1✔
110

1✔
111
        h.handlerHooks = handlerHooks
1✔
112

1✔
113
        return h
1✔
114
}
1✔
115

116
// Dispatch message to corresponding logic handler
117
func (h *HandlerService) Dispatch(thread int) {
×
118
        // TODO: This timer is being stopped multiple times, it probably doesn't need to be stopped here
×
119
        defer timer.GlobalTicker.Stop()
×
120

×
121
        for {
×
122
                // Calls to remote servers block calls to local server
×
123
                select {
×
124
                case lm := <-h.chLocalProcess:
×
125
                        metrics.ReportMessageProcessDelayFromCtx(lm.ctx, h.metricsReporters, "local")
×
126
                        h.localProcess(lm.ctx, lm.agent, lm.route, lm.msg)
×
127

128
                case rm := <-h.chRemoteProcess:
×
129
                        metrics.ReportMessageProcessDelayFromCtx(rm.ctx, h.metricsReporters, "remote")
×
130
                        h.remoteService.remoteProcess(rm.ctx, nil, rm.agent, rm.route, rm.msg)
×
131

132
                case <-timer.GlobalTicker.C: // execute cron task
×
133
                        timer.Cron()
×
134

135
                case t := <-timer.Manager.ChCreatedTimer: // new Timers
×
136
                        timer.AddTimer(t)
×
137

138
                case id := <-timer.Manager.ChClosingTimer: // closing Timers
×
139
                        timer.RemoveTimer(id)
×
140
                }
141
        }
142
}
143

144
// Register registers components
145
func (h *HandlerService) Register(comp component.Component, opts []component.Option) error {
1✔
146
        s := component.NewService(comp, opts)
1✔
147

1✔
148
        if _, ok := h.services[s.Name]; ok {
2✔
149
                return fmt.Errorf("handler: service already defined: %s", s.Name)
1✔
150
        }
1✔
151

152
        if err := s.ExtractHandler(); err != nil {
2✔
153
                return err
1✔
154
        }
1✔
155

156
        // register all handlers
157
        h.services[s.Name] = s
1✔
158
        for name, handler := range s.Handlers {
2✔
159
                h.handlerPool.Register(s.Name, name, handler)
1✔
160
        }
1✔
161
        return nil
1✔
162
}
163

164
// Handle handles messages from a conn
165
func (h *HandlerService) Handle(conn acceptor.PlayerConn) {
1✔
166
        // create a client agent and startup write goroutine
1✔
167
        a := h.agentFactory.CreateAgent(conn)
1✔
168

1✔
169
        // startup agent goroutine
1✔
170
        go a.Handle()
1✔
171

1✔
172
        logger.Log.Debugf("New session established: %s", a.String())
1✔
173

1✔
174
        // guarantee agent related resource is destroyed
1✔
175
        defer func() {
2✔
176
                a.GetSession().Close()
1✔
177
                logger.Log.Debugf("Session read goroutine exit, SessionID=%d, UID=%s", a.GetSession().ID(), a.GetSession().UID())
1✔
178
        }()
1✔
179

180
        for {
2✔
181
                msg, err := conn.GetNextMessage()
1✔
182

1✔
183
                if err != nil {
2✔
184
                        if err != constants.ErrConnectionClosed {
2✔
185
                                logger.Log.Errorf("Error reading next available message: %s", err.Error())
1✔
186
                        }
1✔
187

188
                        return
1✔
189
                }
190

191
                packets, err := h.decoder.Decode(msg)
1✔
192
                if err != nil {
1✔
193
                        logger.Log.Errorf("Failed to decode message: %s", err.Error())
×
194
                        return
×
195
                }
×
196

197
                if len(packets) < 1 {
1✔
198
                        logger.Log.Warnf("Read no packets, data: %v", msg)
×
199
                        continue
×
200
                }
201

202
                // process all packet
203
                for i := range packets {
2✔
204
                        if err := h.processPacket(a, packets[i]); err != nil {
1✔
205
                                logger.Log.Errorf("Failed to process packet: %s", err.Error())
×
206
                                return
×
207
                        }
×
208
                }
209
        }
210
}
211

212
func (h *HandlerService) processPacket(a agent.Agent, p *packet.Packet) error {
1✔
213
        switch p.Type {
1✔
214
        case packet.Handshake:
1✔
215
                logger.Log.Debug("Received handshake packet")
1✔
216

1✔
217
                // Parse the json sent with the handshake by the client
1✔
218
                handshakeData := &session.HandshakeData{}
1✔
219
                if err := json.Unmarshal(p.Data, handshakeData); err != nil {
2✔
220
                        logger.Log.Errorf("Failed to unmarshal handshake data: %s", err.Error())
1✔
221
                        if serr := a.SendHandshakeErrorResponse(); serr != nil {
1✔
222
                                logger.Log.Errorf("Error sending handshake error response: %s", err.Error())
×
223
                                return err
×
224
                        }
×
225

226
                        return fmt.Errorf("invalid handshake data. Id=%d", a.GetSession().ID())
1✔
227
                }
228

229
                if err := a.GetSession().ValidateHandshake(handshakeData); err != nil {
2✔
230
                        logger.Log.Errorf("Handshake validation failed: %s", err.Error())
1✔
231
                        if serr := a.SendHandshakeErrorResponse(); serr != nil {
1✔
232
                                logger.Log.Errorf("Error sending handshake error response: %s", err.Error())
×
233
                                return err
×
234
                        }
×
235

236
                        return fmt.Errorf("handshake validation failed: %w. SessionId=%d", err, a.GetSession().ID())
1✔
237
                }
238

239
                if err := a.SendHandshakeResponse(); err != nil {
1✔
240
                        logger.Log.Errorf("Error sending handshake response: %s", err.Error())
×
241
                        return err
×
242
                }
×
243
                logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.GetSession().ID(), a.RemoteAddr())
1✔
244

1✔
245
                a.GetSession().SetHandshakeData(handshakeData)
1✔
246
                a.SetStatus(constants.StatusHandshake)
1✔
247
                err := a.GetSession().Set(constants.IPVersionKey, a.IPVersion())
1✔
248
                if err != nil {
1✔
249
                        logger.Log.Warnf("failed to save ip version on session: %q\n", err)
×
250
                }
×
251

252
                logger.Log.Debug("Successfully saved handshake data")
1✔
253

254
        case packet.HandshakeAck:
1✔
255
                a.SetStatus(constants.StatusWorking)
1✔
256
                logger.Log.Debugf("Receive handshake ACK Id=%d, Remote=%s", a.GetSession().ID(), a.RemoteAddr())
1✔
257

258
        case packet.Data:
1✔
259
                if a.GetStatus() < constants.StatusWorking {
2✔
260
                        return fmt.Errorf("receive data on socket which is not yet ACK, session will be closed immediately, remote=%s",
1✔
261
                                a.RemoteAddr().String())
1✔
262
                }
1✔
263

264
                msg, err := message.Decode(p.Data)
1✔
265
                if err != nil {
2✔
266
                        return err
1✔
267
                }
1✔
268
                h.processMessage(a, msg)
1✔
269

270
        case packet.Heartbeat:
1✔
271
                // expected
272
        }
273

274
        a.SetLastAt()
1✔
275
        return nil
1✔
276
}
277

278
func (h *HandlerService) processMessage(a agent.Agent, msg *message.Message) {
1✔
279
        requestID := nuid.New().Next()
1✔
280
        ctx := pcontext.AddToPropagateCtx(context.Background(), constants.StartTimeKey, time.Now().UnixNano())
1✔
281
        ctx = pcontext.AddToPropagateCtx(ctx, constants.RouteKey, msg.Route)
1✔
282
        ctx = pcontext.AddToPropagateCtx(ctx, constants.RequestIDKey, requestID)
1✔
283
        tags := opentracing.Tags{
1✔
284
                "local.id":   h.server.ID,
1✔
285
                "span.kind":  "server",
1✔
286
                "msg.type":   strings.ToLower(msg.Type.String()),
1✔
287
                "user.id":    a.GetSession().UID(),
1✔
288
                "request.id": requestID,
1✔
289
        }
1✔
290
        ctx = tracing.StartSpan(ctx, msg.Route, tags)
1✔
291
        ctx = context.WithValue(ctx, constants.SessionCtxKey, a.GetSession())
1✔
292

1✔
293
        r, err := route.Decode(msg.Route)
1✔
294
        if err != nil {
2✔
295
                logger.Log.Errorf("Failed to decode route: %s", err.Error())
1✔
296
                a.AnswerWithError(ctx, msg.ID, e.NewError(err, e.ErrBadRequestCode))
1✔
297
                return
1✔
298
        }
1✔
299

300
        if r.SvType == "" {
2✔
301
                r.SvType = h.server.Type
1✔
302
        }
1✔
303

304
        message := unhandledMessage{
1✔
305
                ctx:   ctx,
1✔
306
                agent: a,
1✔
307
                route: r,
1✔
308
                msg:   msg,
1✔
309
        }
1✔
310
        if r.SvType == h.server.Type {
2✔
311
                h.chLocalProcess <- message
1✔
312
        } else {
2✔
313
                if h.remoteService != nil {
2✔
314
                        h.chRemoteProcess <- message
1✔
315
                } else {
1✔
316
                        logger.Log.Warnf("request made to another server type but no remoteService running")
×
317
                }
×
318
        }
319
}
320

321
func (h *HandlerService) localProcess(ctx context.Context, a agent.Agent, route *route.Route, msg *message.Message) {
1✔
322
        var mid uint
1✔
323
        switch msg.Type {
1✔
324
        case message.Request:
1✔
325
                mid = msg.ID
1✔
326
        case message.Notify:
×
327
                mid = 0
×
328
        }
329

330
        ret, err := h.handlerPool.ProcessHandlerMessage(ctx, route, h.serializer, h.handlerHooks, a.GetSession(), msg.Data, msg.Type, false)
1✔
331
        if msg.Type != message.Notify {
2✔
332
                if err != nil {
2✔
333
                        logger.Log.Errorf("Failed to process handler message: %s", err.Error())
1✔
334
                        a.AnswerWithError(ctx, mid, err)
1✔
335
                } else {
2✔
336
                        err := a.GetSession().ResponseMID(ctx, mid, ret)
1✔
337
                        if err != nil {
1✔
338
                                tracing.FinishSpan(ctx, err)
×
339
                                metrics.ReportTimingFromCtx(ctx, h.metricsReporters, handlerType, err)
×
340
                        }
×
341
                }
342
        } else {
×
343
                metrics.ReportTimingFromCtx(ctx, h.metricsReporters, handlerType, nil)
×
344
                tracing.FinishSpan(ctx, err)
×
345
        }
×
346
}
347

348
// DumpServices outputs all registered services
349
func (h *HandlerService) DumpServices() {
×
350
        handlers := h.handlerPool.GetHandlers()
×
351
        for name := range handlers {
×
352
                logger.Log.Infof("registered handler %s, isRawArg: %v", name, handlers[name].IsRawArg)
×
353
        }
×
354
}
355

356
// Docs returns documentation for handlers
357
func (h *HandlerService) Docs(getPtrNames bool) (map[string]interface{}, error) {
×
358
        if h == nil {
×
359
                return map[string]interface{}{}, nil
×
360
        }
×
361
        return docgenerator.HandlersDocs(h.server.Type, h.services, getPtrNames)
×
362
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc