@@ -24,6 +24,8 @@ type testTracer struct {
2424 traceConnectEnd func (ctx context.Context , data pgx.TraceConnectEndData )
2525}
2626
27+ type ctxKey string
28+
2729func (tt * testTracer ) TraceQueryStart (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryStartData ) context.Context {
2830 if tt .traceQueryStart != nil {
2931 return tt .traceQueryStart (ctx , conn , data )
@@ -117,13 +119,13 @@ func TestTraceExec(t *testing.T) {
117119 require .Equal (t , `select $1::text` , data .SQL )
118120 require .Len (t , data .Args , 1 )
119121 require .Equal (t , `testing` , data .Args [0 ])
120- return context .WithValue (ctx , "fromTraceQueryStart" , "foo" )
122+ return context .WithValue (ctx , ctxKey ( ctxKey ( "fromTraceQueryStart" )) , "foo" )
121123 }
122124
123125 traceQueryEndCalled := false
124126 tracer .traceQueryEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryEndData ) {
125127 traceQueryEndCalled = true
126- require .Equal (t , "foo" , ctx .Value ("fromTraceQueryStart" ))
128+ require .Equal (t , "foo" , ctx .Value (ctxKey ( ctxKey ( "fromTraceQueryStart" )) ))
127129 require .Equal (t , `SELECT 1` , data .CommandTag .String ())
128130 require .NoError (t , data .Err )
129131 }
@@ -157,13 +159,13 @@ func TestTraceQuery(t *testing.T) {
157159 require .Equal (t , `select $1::text` , data .SQL )
158160 require .Len (t , data .Args , 1 )
159161 require .Equal (t , `testing` , data .Args [0 ])
160- return context .WithValue (ctx , "fromTraceQueryStart" , "foo" )
162+ return context .WithValue (ctx , ctxKey ( "fromTraceQueryStart" ) , "foo" )
161163 }
162164
163165 traceQueryEndCalled := false
164166 tracer .traceQueryEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryEndData ) {
165167 traceQueryEndCalled = true
166- require .Equal (t , "foo" , ctx .Value ("fromTraceQueryStart" ))
168+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceQueryStart" ) ))
167169 require .Equal (t , `SELECT 1` , data .CommandTag .String ())
168170 require .NoError (t , data .Err )
169171 }
@@ -198,20 +200,20 @@ func TestTraceBatchNormal(t *testing.T) {
198200 traceBatchStartCalled = true
199201 require .NotNil (t , data .Batch )
200202 require .Equal (t , 2 , data .Batch .Len ())
201- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
203+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
202204 }
203205
204206 traceBatchQueryCalledCount := 0
205207 tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
206208 traceBatchQueryCalledCount ++
207- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
209+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
208210 require .NoError (t , data .Err )
209211 }
210212
211213 traceBatchEndCalled := false
212214 tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
213215 traceBatchEndCalled = true
214- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
216+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
215217 require .NoError (t , data .Err )
216218 }
217219
@@ -261,20 +263,20 @@ func TestTraceBatchClose(t *testing.T) {
261263 traceBatchStartCalled = true
262264 require .NotNil (t , data .Batch )
263265 require .Equal (t , 2 , data .Batch .Len ())
264- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
266+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
265267 }
266268
267269 traceBatchQueryCalledCount := 0
268270 tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
269271 traceBatchQueryCalledCount ++
270- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
272+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
271273 require .NoError (t , data .Err )
272274 }
273275
274276 traceBatchEndCalled := false
275277 tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
276278 traceBatchEndCalled = true
277- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
279+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
278280 require .NoError (t , data .Err )
279281 }
280282
@@ -312,13 +314,13 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
312314 traceBatchStartCalled = true
313315 require .NotNil (t , data .Batch )
314316 require .Equal (t , 3 , data .Batch .Len ())
315- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
317+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
316318 }
317319
318320 traceBatchQueryCalledCount := 0
319321 tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
320322 traceBatchQueryCalledCount ++
321- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
323+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
322324 if traceBatchQueryCalledCount == 2 {
323325 require .Error (t , data .Err )
324326 } else {
@@ -329,7 +331,7 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
329331 traceBatchEndCalled := false
330332 tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
331333 traceBatchEndCalled = true
332- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
334+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
333335 require .Error (t , data .Err )
334336 }
335337
@@ -381,13 +383,13 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
381383 traceBatchStartCalled = true
382384 require .NotNil (t , data .Batch )
383385 require .Equal (t , 3 , data .Batch .Len ())
384- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
386+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
385387 }
386388
387389 traceBatchQueryCalledCount := 0
388390 tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
389391 traceBatchQueryCalledCount ++
390- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
392+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
391393 if traceBatchQueryCalledCount == 2 {
392394 require .Error (t , data .Err )
393395 } else {
@@ -398,7 +400,7 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
398400 traceBatchEndCalled := false
399401 tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
400402 traceBatchEndCalled = true
401- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
403+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
402404 require .Error (t , data .Err )
403405 }
404406
@@ -440,13 +442,13 @@ func TestTraceCopyFrom(t *testing.T) {
440442 traceCopyFromStartCalled = true
441443 require .Equal (t , pgx.Identifier {"foo" }, data .TableName )
442444 require .Equal (t , []string {"a" }, data .ColumnNames )
443- return context .WithValue (ctx , "fromTraceCopyFromStart" , "foo" )
445+ return context .WithValue (ctx , ctxKey ( "fromTraceCopyFromStart" ) , "foo" )
444446 }
445447
446448 traceCopyFromEndCalled := false
447449 tracer .traceCopyFromEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceCopyFromEndData ) {
448450 traceCopyFromEndCalled = true
449- require .Equal (t , "foo" , ctx .Value ("fromTraceCopyFromStart" ))
451+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceCopyFromStart" ) ))
450452 require .Equal (t , `COPY 2` , data .CommandTag .String ())
451453 require .NoError (t , data .Err )
452454 }
@@ -488,7 +490,7 @@ func TestTracePrepare(t *testing.T) {
488490 tracePrepareStartCalled = true
489491 require .Equal (t , `ps` , data .Name )
490492 require .Equal (t , `select $1::text` , data .SQL )
491- return context .WithValue (ctx , "fromTracePrepareStart" , "foo" )
493+ return context .WithValue (ctx , ctxKey ( "fromTracePrepareStart" ) , "foo" )
492494 }
493495
494496 tracePrepareEndCalled := false
@@ -530,7 +532,7 @@ func TestTraceConnect(t *testing.T) {
530532 tracer .traceConnectStart = func (ctx context.Context , data pgx.TraceConnectStartData ) context.Context {
531533 traceConnectStartCalled = true
532534 require .NotNil (t , data .ConnConfig )
533- return context .WithValue (ctx , "fromTraceConnectStart" , "foo" )
535+ return context .WithValue (ctx , ctxKey ( "fromTraceConnectStart" ) , "foo" )
534536 }
535537
536538 traceConnectEndCalled := false
0 commit comments