1
0
Fork 0
mirror of https://github.com/miniflux/v2.git synced 2025-06-27 16:36:00 +00:00

Add FeedIcon API call and update dependencies

This commit is contained in:
Frédéric Guillot 2017-12-16 11:25:18 -08:00
parent 231ebf2daa
commit 27196589fb
262 changed files with 83830 additions and 30061 deletions

22
Gopkg.lock generated
View file

@ -41,55 +41,55 @@
branch = "master" branch = "master"
name = "github.com/miniflux/miniflux-go" name = "github.com/miniflux/miniflux-go"
packages = ["."] packages = ["."]
revision = "c5788cd2d2248ee9fc148f3852dda7e24fe54cfa" revision = "ecd111d16e0ce1468cb3b786135c18b3fdc96213"
[[projects]] [[projects]]
name = "github.com/tdewolff/minify" name = "github.com/tdewolff/minify"
packages = [".","css","js"] packages = [".","css","js"]
revision = "90df1aae5028a7cbb441bde86e86a55df6b5aa34" revision = "222672169d634c440a73abc47685074e1a9daa60"
version = "v2.3.3" version = "v2.3.4"
[[projects]] [[projects]]
name = "github.com/tdewolff/parse" name = "github.com/tdewolff/parse"
packages = [".","buffer","css","js","strconv"] packages = [".","buffer","css","js","strconv"]
revision = "bace4cf682c41e03b154044b561575ff541b83e8" revision = "639f6272aec6b52094db77b9ec488214b0b4b1a1"
version = "v2.3.1" version = "v2.3.2"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/tomasen/realip" name = "github.com/tomasen/realip"
packages = ["."] packages = ["."]
revision = "15489afd3be348430f5f67467d2bb6b2f9b757ed" revision = "b5850897b7b539a1c9f22cdaa3b547d1bd453db8"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = ["acme","acme/autocert","bcrypt","blowfish","ssh/terminal"] packages = ["acme","acme/autocert","bcrypt","blowfish","ssh/terminal"]
revision = "b080dc9a8c480b08e698fb1219160d598526310f" revision = "94eea52f7b742c7cbe0b03b22f0c4c8631ece122"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/net" name = "golang.org/x/net"
packages = ["context","context/ctxhttp","html","html/atom","html/charset"] packages = ["context","context/ctxhttp","html","html/atom","html/charset"]
revision = "c7086645de248775cbf2373cf5ca4d2fa664b8c1" revision = "d866cfc389cec985d6fda2859936a575a55a3ab6"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/oauth2" name = "golang.org/x/oauth2"
packages = [".","internal"] packages = [".","internal"]
revision = "f95fa95eaa936d9d87489b15d1d18b97c1ba9c28" revision = "462316686f20eb6df426961c1c131bdaa5dfa68e"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/sys" name = "golang.org/x/sys"
packages = ["unix","windows"] packages = ["unix","windows"]
revision = "4ff8c001ce4cc464e644b922325097228fce14d8" revision = "571f7bbbe08da2a8955aed9d4db316e78630e9a3"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/text" name = "golang.org/x/text"
packages = ["encoding","encoding/charmap","encoding/htmlindex","encoding/internal","encoding/internal/identifier","encoding/japanese","encoding/korean","encoding/simplifiedchinese","encoding/traditionalchinese","encoding/unicode","internal/gen","internal/tag","internal/utf8internal","language","runes","transform","unicode/cldr"] packages = ["encoding","encoding/charmap","encoding/htmlindex","encoding/internal","encoding/internal/identifier","encoding/japanese","encoding/korean","encoding/simplifiedchinese","encoding/traditionalchinese","encoding/unicode","internal/gen","internal/tag","internal/utf8internal","language","runes","transform","unicode/cldr"]
revision = "88f656faf3f37f690df1a32515b479415e1a6769" revision = "d5a9226ed7dd70cade6ccae9d37517fe14dd9fee"
[[projects]] [[projects]]
name = "google.golang.org/appengine" name = "google.golang.org/appengine"

View file

@ -21,6 +21,7 @@ const (
testAdminUsername = "admin" testAdminUsername = "admin"
testAdminPassword = "test123" testAdminPassword = "test123"
testStandardPassword = "secret" testStandardPassword = "secret"
testFeedURL = "https://github.com/miniflux/miniflux/commits/master.atom"
) )
func TestWithBadEndpoint(t *testing.T) { func TestWithBadEndpoint(t *testing.T) {
@ -714,6 +715,57 @@ func TestGetFeed(t *testing.T) {
} }
} }
func TestGetFeedIcon(t *testing.T) {
username := getRandomUsername()
client := miniflux.NewClient(testBaseURL, testAdminUsername, testAdminPassword)
_, err := client.CreateUser(username, testStandardPassword, false)
if err != nil {
t.Fatal(err)
}
client = miniflux.NewClient(testBaseURL, username, testStandardPassword)
categories, err := client.Categories()
if err != nil {
t.Fatal(err)
}
feedID, err := client.CreateFeed(testFeedURL, categories[0].ID)
if err != nil {
t.Fatal(err)
}
feedIcon, err := client.FeedIcon(feedID)
if err != nil {
t.Fatal(err)
}
if feedIcon.ID == 0 {
t.Fatalf(`Invalid feed icon ID, got "%v"`, feedIcon.ID)
}
if feedIcon.MimeType != "image/x-icon" {
t.Fatalf(`Invalid feed icon mime type, got "%v" instead of "%v"`, feedIcon.MimeType, "image/x-icon")
}
if !strings.Contains(feedIcon.Data, "image/x-icon") {
t.Fatalf(`Invalid feed icon data, got "%v"`, feedIcon.Data)
}
}
func TestGetFeedIconNotFound(t *testing.T) {
username := getRandomUsername()
client := miniflux.NewClient(testBaseURL, testAdminUsername, testAdminPassword)
_, err := client.CreateUser(username, testStandardPassword, false)
if err != nil {
t.Fatal(err)
}
client = miniflux.NewClient(testBaseURL, username, testStandardPassword)
if _, err := client.FeedIcon(42); err == nil {
t.Fatalf(`The feed icon should be null`)
}
}
func TestGetFeeds(t *testing.T) { func TestGetFeeds(t *testing.T) {
username := getRandomUsername() username := getRandomUsername()
client := miniflux.NewClient(testBaseURL, testAdminUsername, testAdminPassword) client := miniflux.NewClient(testBaseURL, testAdminUsername, testAdminPassword)

View file

@ -0,0 +1,44 @@
// Copyright 2017 Frédéric Guillot. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package api
import (
"errors"
"github.com/miniflux/miniflux/server/api/payload"
"github.com/miniflux/miniflux/server/core"
)
// FeedIcon returns a feed icon.
func (c *Controller) FeedIcon(ctx *core.Context, request *core.Request, response *core.Response) {
userID := ctx.UserID()
feedID, err := request.IntegerParam("feedID")
if err != nil {
response.JSON().BadRequest(err)
return
}
if !c.store.HasIcon(feedID) {
response.JSON().NotFound(errors.New("This feed doesn't have any icon"))
return
}
icon, err := c.store.IconByFeedID(userID, feedID)
if err != nil {
response.JSON().ServerError(errors.New("Unable to fetch feed icon"))
return
}
if icon == nil {
response.JSON().NotFound(errors.New("This feed doesn't have any icon"))
return
}
response.JSON().Standard(&payload.FeedIcon{
ID: icon.ID,
MimeType: icon.MimeType,
Data: icon.DataURL(),
})
}

View file

@ -12,6 +12,13 @@ import (
"github.com/miniflux/miniflux/model" "github.com/miniflux/miniflux/model"
) )
// FeedIcon represents the feed icon response.
type FeedIcon struct {
ID int64 `json:"id"`
MimeType string `json:"mime_type"`
Data string `json:"data"`
}
// EntriesResponse represents the response sent when fetching entries. // EntriesResponse represents the response sent when fetching entries.
type EntriesResponse struct { type EntriesResponse struct {
Total int `json:"total"` Total int `json:"total"`

View file

@ -67,6 +67,7 @@ func getRoutes(cfg *config.Config, store *storage.Storage, feedHandler *feed.Han
router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.GetFeed)).Methods("GET") router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.GetFeed)).Methods("GET")
router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.UpdateFeed)).Methods("PUT") router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.UpdateFeed)).Methods("PUT")
router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.RemoveFeed)).Methods("DELETE") router.Handle("/v1/feeds/{feedID}", apiHandler.Use(apiController.RemoveFeed)).Methods("DELETE")
router.Handle("/v1/feeds/{feedID}/icon", apiHandler.Use(apiController.FeedIcon)).Methods("GET")
router.Handle("/v1/feeds/{feedID}/entries", apiHandler.Use(apiController.GetFeedEntries)).Methods("GET") router.Handle("/v1/feeds/{feedID}/entries", apiHandler.Use(apiController.GetFeedEntries)).Methods("GET")
router.Handle("/v1/feeds/{feedID}/entries/{entryID}", apiHandler.Use(apiController.GetEntry)).Methods("GET") router.Handle("/v1/feeds/{feedID}/entries/{entryID}", apiHandler.Use(apiController.GetEntry)).Methods("GET")

View file

@ -38,6 +38,28 @@ func (s *Storage) IconByID(iconID int64) (*model.Icon, error) {
return &icon, nil return &icon, nil
} }
// IconByFeedID returns a feed icon.
func (s *Storage) IconByFeedID(userID, feedID int64) (*model.Icon, error) {
defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:IconByFeedID] userID=%d, feedID=%d", userID, feedID))
query := `
SELECT
icons.id, icons.hash, icons.mime_type, icons.content
FROM icons
LEFT JOIN feed_icons ON feed_icons.icon_id=icons.id
LEFT JOIN feeds ON feeds.id=feed_icons.feed_id
WHERE feeds.user_id=$1 AND feeds.id=$2
LIMIT 1
`
var icon model.Icon
err := s.db.QueryRow(query, userID, feedID).Scan(&icon.ID, &icon.Hash, &icon.MimeType, &icon.Content)
if err != nil {
return nil, fmt.Errorf("unable to fetch icon: %v", err)
}
return &icon, nil
}
// IconByHash returns an icon by the hash (checksum). // IconByHash returns an icon by the hash (checksum).
func (s *Storage) IconByHash(icon *model.Icon) error { func (s *Storage) IconByHash(icon *model.Icon) error {
defer helper.ExecutionTime(time.Now(), "[Storage:IconByHash]") defer helper.ExecutionTime(time.Now(), "[Storage:IconByHash]")

View file

@ -274,6 +274,23 @@ func (c *Client) DeleteFeed(feedID int64) error {
return nil return nil
} }
// FeedIcon gets a feed icon.
func (c *Client) FeedIcon(feedID int64) (*FeedIcon, error) {
body, err := c.request.Get(fmt.Sprintf("/v1/feeds/%d/icon", feedID))
if err != nil {
return nil, err
}
defer body.Close()
var feedIcon *FeedIcon
decoder := json.NewDecoder(body)
if err := decoder.Decode(&feedIcon); err != nil {
return nil, fmt.Errorf("miniflux: response error (%v)", err)
}
return feedIcon, nil
}
// Entry gets a single feed entry. // Entry gets a single feed entry.
func (c *Client) Entry(feedID, entryID int64) (*Entry, error) { func (c *Client) Entry(feedID, entryID int64) (*Entry, error) {
body, err := c.request.Get(fmt.Sprintf("/v1/feeds/%d/entries/%d", feedID, entryID)) body, err := c.request.Get(fmt.Sprintf("/v1/feeds/%d/entries/%d", feedID, entryID))

View file

@ -79,6 +79,13 @@ type Feed struct {
Entries Entries `json:"entries,omitempty"` Entries Entries `json:"entries,omitempty"`
} }
// FeedIcon represents the feed icon.
type FeedIcon struct {
ID int64 `json:"id"`
MimeType string `json:"mime_type"`
Data string `json:"data"`
}
// Feeds represents a list of feeds. // Feeds represents a list of feeds.
type Feeds []*Feed type Feeds []*Feed

View file

@ -8,7 +8,7 @@
--- ---
Minify is a minifier package written in [Go][1]. It provides HTML5, CSS3, JS, JSON, SVG and XML minifiers and an interface to implement any other minifier. Minification is the process of removing bytes from a file (such as whitespace) without changing its output and therefore shrinking its size and speeding up transmission over the internet and possibly parsing. The implemented minifiers are high performance and streaming, which implies O(n). Minify is a minifier package written in [Go][1]. It provides HTML5, CSS3, JS, JSON, SVG and XML minifiers and an interface to implement any other minifier. Minification is the process of removing bytes from a file (such as whitespace) without changing its output and therefore shrinking its size and speeding up transmission over the internet and possibly parsing. The implemented minifiers are designed for high performance.
The core functionality associates mimetypes with minification functions, allowing embedded resources (like CSS or JS within HTML files) to be minified as well. Users can add new implementations that are triggered based on a mimetype (or pattern), or redirect to an external command (like ClosureCompiler, UglifyCSS, ...). The core functionality associates mimetypes with minification functions, allowing embedded resources (like CSS or JS within HTML files) to be minified as well. Users can add new implementations that are triggered based on a mimetype (or pattern), or redirect to an external command (like ClosureCompiler, UglifyCSS, ...).
@ -100,51 +100,52 @@ The benchmarks directory contains a number of standardized samples used to compa
``` ```
name time/op name time/op
CSS/sample_bootstrap.css-4 3.05ms ± 1% CSS/sample_bootstrap.css-4 2.26ms ± 0%
CSS/sample_gumby.css-4 4.25ms ± 1% CSS/sample_gumby.css-4 2.92ms ± 1%
HTML/sample_amazon.html-4 3.33ms ± 0% HTML/sample_amazon.html-4 2.33ms ± 2%
HTML/sample_bbc.html-4 1.39ms ± 7% HTML/sample_bbc.html-4 1.02ms ± 1%
HTML/sample_blogpost.html-4 222µs ± 1% HTML/sample_blogpost.html-4 171µs ± 2%
HTML/sample_es6.html-4 18.0ms ± 1% HTML/sample_es6.html-4 14.5ms ± 0%
HTML/sample_stackoverflow.html-4 3.08ms ± 1% HTML/sample_stackoverflow.html-4 2.41ms ± 1%
HTML/sample_wikipedia.html-4 6.06ms ± 1% HTML/sample_wikipedia.html-4 4.76ms ± 0%
JS/sample_ace.js-4 9.92ms ± 1% JS/sample_ace.js-4 7.41ms ± 0%
JS/sample_dot.js-4 91.4µs ± 4% JS/sample_dot.js-4 63.7µs ± 0%
JS/sample_jquery.js-4 4.00ms ± 1% JS/sample_jquery.js-4 2.99ms ± 0%
JS/sample_jqueryui.js-4 7.93ms ± 0% JS/sample_jqueryui.js-4 5.92ms ± 2%
JS/sample_moment.js-4 1.46ms ± 1% JS/sample_moment.js-4 1.09ms ± 1%
JSON/sample_large.json-4 5.07ms ± 4% JSON/sample_large.json-4 2.95ms ± 0%
JSON/sample_testsuite.json-4 2.96ms ± 0% JSON/sample_testsuite.json-4 1.51ms ± 1%
JSON/sample_twitter.json-4 11.3µs ± 0% JSON/sample_twitter.json-4 6.75µs ± 1%
SVG/sample_arctic.svg-4 64.7ms ± 0% SVG/sample_arctic.svg-4 62.3ms ± 1%
SVG/sample_gopher.svg-4 227µs ± 0% SVG/sample_gopher.svg-4 218µs ± 0%
SVG/sample_usa.svg-4 35.9ms ± 6% SVG/sample_usa.svg-4 33.1ms ± 3%
XML/sample_books.xml-4 48.1µs ± 4% XML/sample_books.xml-4 36.2µs ± 0%
XML/sample_catalog.xml-4 20.2µs ± 0% XML/sample_catalog.xml-4 14.9µs ± 0%
XML/sample_omg.xml-4 9.02ms ± 0% XML/sample_omg.xml-4 6.31ms ± 1%
name speed name speed
CSS/sample_bootstrap.css-4 45.0MB/s ± 1% CSS/sample_bootstrap.css-4 60.8MB/s ± 0%
CSS/sample_gumby.css-4 43.8MB/s ± 1% CSS/sample_gumby.css-4 63.9MB/s ± 1%
HTML/sample_amazon.html-4 142MB/s ± 0% HTML/sample_amazon.html-4 203MB/s ± 2%
HTML/sample_bbc.html-4 83.0MB/s ± 7% HTML/sample_bbc.html-4 113MB/s ± 1%
HTML/sample_blogpost.html-4 94.5MB/s ± 1% HTML/sample_blogpost.html-4 123MB/s ± 2%
HTML/sample_es6.html-4 56.8MB/s ± 1% HTML/sample_es6.html-4 70.7MB/s ± 0%
HTML/sample_stackoverflow.html-4 66.7MB/s ± 1% HTML/sample_stackoverflow.html-4 85.2MB/s ± 1%
HTML/sample_wikipedia.html-4 73.5MB/s ± 1% HTML/sample_wikipedia.html-4 93.6MB/s ± 0%
JS/sample_ace.js-4 64.9MB/s ± 1% JS/sample_ace.js-4 86.9MB/s ± 0%
JS/sample_dot.js-4 56.4MB/s ± 4% JS/sample_dot.js-4 81.0MB/s ± 0%
JS/sample_jquery.js-4 61.8MB/s ± 1% JS/sample_jquery.js-4 82.8MB/s ± 0%
JS/sample_jqueryui.js-4 59.2MB/s ± 0% JS/sample_jqueryui.js-4 79.3MB/s ± 2%
JS/sample_moment.js-4 67.8MB/s ± 1% JS/sample_moment.js-4 91.2MB/s ± 1%
JSON/sample_large.json-4 150MB/s ± 4% JSON/sample_large.json-4 258MB/s ± 0%
JSON/sample_testsuite.json-4 233MB/s ± 0% JSON/sample_testsuite.json-4 457MB/s ± 1%
JSON/sample_twitter.json-4 134MB/s ± 0% JSON/sample_twitter.json-4 226MB/s ± 1%
SVG/sample_arctic.svg-4 22.7MB/s ± 0% SVG/sample_arctic.svg-4 23.6MB/s ± 1%
SVG/sample_gopher.svg-4 25.6MB/s ± 0% SVG/sample_gopher.svg-4 26.7MB/s ± 0%
SVG/sample_usa.svg-4 28.6MB/s ± 6% SVG/sample_usa.svg-4 30.9MB/s ± 3%
XML/sample_books.xml-4 92.1MB/s ± 4% XML/sample_books.xml-4 122MB/s ± 0%
XML/sample_catalog.xml-4 95.6MB/s ± 0% XML/sample_catalog.xml-4 130MB/s ± 0%
XML/sample_omg.xml-4 180MB/s ± 1%
``` ```
## HTML ## HTML

View file

@ -116,6 +116,11 @@ func TestHTML(t *testing.T) {
{`<meta e t n content=ful><a b`, `<meta e t n content=ful><a b>`}, {`<meta e t n content=ful><a b`, `<meta e t n content=ful><a b>`},
{`<img alt=a'b="">`, `<img alt='a&#39;b=""'>`}, {`<img alt=a'b="">`, `<img alt='a&#39;b=""'>`},
{`</b`, `</b`}, {`</b`, `</b`},
{`<title></`, `<title></`},
{`<svg <`, `<svg <`},
{`<svg "`, `<svg "`},
{`<svg></`, `<svg></`},
{`<script><!--<`, `<script><!--<`},
// bugs // bugs
{`<p>text</p><br>text`, `<p>text</p><br>text`}, // #122 {`<p>text</p><br>text`, `<p>text</p><br>text`}, // #122

View file

@ -40,6 +40,9 @@ func TestJS(t *testing.T) {
{"false\n\"string\"", "false\n\"string\""}, // #109 {"false\n\"string\"", "false\n\"string\""}, // #109
{"`\n", "`"}, // go fuzz {"`\n", "`"}, // go fuzz
{"a\n~b", "a\n~b"}, // #132 {"a\n~b", "a\n~b"}, // #132
// go-fuzz
{`/\`, `/\`},
} }
m := minify.New() m := minify.New()

View file

@ -43,6 +43,9 @@ func (o *Minifier) Minify(m *minify.M, w io.Writer, r io.Reader, _ map[string]st
for { for {
t := *tb.Shift() t := *tb.Shift()
if t.TokenType == xml.CDATAToken { if t.TokenType == xml.CDATAToken {
if len(t.Text) == 0 {
continue
}
if text, useText := xml.EscapeCDATAVal(&attrByteBuffer, t.Text); useText { if text, useText := xml.EscapeCDATAVal(&attrByteBuffer, t.Text); useText {
t.TokenType = xml.TextToken t.TokenType = xml.TextToken
t.Data = text t.Data = text

View file

@ -39,7 +39,10 @@ func TestXML(t *testing.T) {
{"<style>lala{color:red}</style>", "<style>lala{color:red}</style>"}, {"<style>lala{color:red}</style>", "<style>lala{color:red}</style>"},
{`cats and dogs `, `cats and dogs`}, {`cats and dogs `, `cats and dogs`},
{`</0`, `</0`}, // go fuzz // go fuzz
{`</0`, `</0`},
{`<!DOCTYPE`, `<!DOCTYPE`},
{`<![CDATA[`, ``},
} }
m := minify.New() m := minify.New()

View file

@ -81,9 +81,14 @@ func (z *Lexer) Restore() {
// Err returns the error returned from io.Reader or io.EOF when the end has been reached. // Err returns the error returned from io.Reader or io.EOF when the end has been reached.
func (z *Lexer) Err() error { func (z *Lexer) Err() error {
return z.PeekErr(0)
}
// PeekErr returns the error at position pos. When pos is zero, this is the same as calling Err().
func (z *Lexer) PeekErr(pos int) error {
if z.err != nil { if z.err != nil {
return z.err return z.err
} else if z.pos >= len(z.buf)-1 { } else if z.pos+pos >= len(z.buf)-1 {
return io.EOF return io.EOF
} }
return nil return nil

View file

@ -174,7 +174,8 @@ func TestParseError(t *testing.T) {
if tt.col == 0 { if tt.col == 0 {
test.T(t, p.Err(), io.EOF) test.T(t, p.Err(), io.EOF)
} else if perr, ok := p.Err().(*parse.Error); ok { } else if perr, ok := p.Err().(*parse.Error); ok {
test.T(t, perr.Col, tt.col) _, col, _ := perr.Position()
test.T(t, col, tt.col)
} else { } else {
test.Fail(t, "bad error:", p.Err()) test.Fail(t, "bad error:", p.Err())
} }

View file

@ -7,29 +7,43 @@ import (
"github.com/tdewolff/parse/buffer" "github.com/tdewolff/parse/buffer"
) )
// Error is a parsing error returned by parser. It contains a message and an offset at which the error occurred.
type Error struct { type Error struct {
Message string Message string
Line int r io.Reader
Col int Offset int
Context string line int
column int
context string
} }
// NewError creates a new error
func NewError(msg string, r io.Reader, offset int) *Error { func NewError(msg string, r io.Reader, offset int) *Error {
line, col, context, _ := Position(r, offset)
return &Error{ return &Error{
msg, Message: msg,
line, r: r,
col, Offset: offset,
context,
} }
} }
// NewErrorLexer creates a new error from a *buffer.Lexer
func NewErrorLexer(msg string, l *buffer.Lexer) *Error { func NewErrorLexer(msg string, l *buffer.Lexer) *Error {
r := buffer.NewReader(l.Bytes()) r := buffer.NewReader(l.Bytes())
offset := l.Offset() offset := l.Offset()
return NewError(msg, r, offset) return NewError(msg, r, offset)
} }
func (e *Error) Error() string { // Positions re-parses the file to determine the line, column, and context of the error.
return fmt.Sprintf("parse error:%d:%d: %s\n%s", e.Line, e.Col, e.Message, e.Context) // Context is the entire line at which the error occurred.
func (e *Error) Position() (int, int, string) {
if e.line == 0 {
e.line, e.column, e.context, _ = Position(e.r, e.Offset)
}
return e.line, e.column, e.context
}
// Error returns the error string, containing the context and line + column number.
func (e *Error) Error() string {
line, column, context := e.Position()
return fmt.Sprintf("parse error:%d:%d: %s\n%s", line, column, e.Message, context)
} }

View file

@ -79,10 +79,10 @@ func NewLexer(r io.Reader) *Lexer {
// Err returns the error encountered during lexing, this is often io.EOF but also other errors can be returned. // Err returns the error encountered during lexing, this is often io.EOF but also other errors can be returned.
func (l *Lexer) Err() error { func (l *Lexer) Err() error {
if err := l.r.Err(); err != nil { if l.err != nil {
return err return l.err
} }
return l.err return l.r.Err()
} }
// Restore restores the NULL byte at the end of the buffer. // Restore restores the NULL byte at the end of the buffer.
@ -103,8 +103,7 @@ func (l *Lexer) Next() (TokenType, []byte) {
} }
break break
} }
if c == 0 { if c == 0 && l.r.Err() != nil {
l.err = parse.NewErrorLexer("unexpected null character", l.r)
return ErrorToken, nil return ErrorToken, nil
} else if c != '>' && (c != '/' || l.r.Peek(1) != '>') { } else if c != '>' && (c != '/' || l.r.Peek(1) != '>') {
return AttributeToken, l.shiftAttribute() return AttributeToken, l.shiftAttribute()
@ -133,13 +132,16 @@ func (l *Lexer) Next() (TokenType, []byte) {
c = l.r.Peek(0) c = l.r.Peek(0)
if c == '<' { if c == '<' {
c = l.r.Peek(1) c = l.r.Peek(1)
isEndTag := c == '/' && l.r.Peek(2) != '>' && (l.r.Peek(2) != 0 || l.r.PeekErr(2) == nil)
if l.r.Pos() > 0 { if l.r.Pos() > 0 {
if c == '/' && l.r.Peek(2) != 0 || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '!' || c == '?' { if isEndTag || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '!' || c == '?' {
// return currently buffered texttoken so that we can return tag next iteration
return TextToken, l.r.Shift() return TextToken, l.r.Shift()
} }
} else if c == '/' && l.r.Peek(2) != 0 { } else if isEndTag {
l.r.Move(2) l.r.Move(2)
if c = l.r.Peek(0); c != '>' && !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') { // only endtags that are not followed by > or EOF arrive here
if c = l.r.Peek(0); !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') {
return CommentToken, l.shiftBogusComment() return CommentToken, l.shiftBogusComment()
} }
return EndTagToken, l.shiftEndTag() return EndTagToken, l.shiftEndTag()
@ -154,11 +156,10 @@ func (l *Lexer) Next() (TokenType, []byte) {
l.r.Move(1) l.r.Move(1)
return CommentToken, l.shiftBogusComment() return CommentToken, l.shiftBogusComment()
} }
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
if l.r.Pos() > 0 { if l.r.Pos() > 0 {
return TextToken, l.r.Shift() return TextToken, l.r.Shift()
} }
l.err = parse.NewErrorLexer("unexpected null character", l.r)
return ErrorToken, nil return ErrorToken, nil
} }
l.r.Move(1) l.r.Move(1)
@ -182,7 +183,7 @@ func (l *Lexer) AttrVal() []byte {
func (l *Lexer) shiftRawText() []byte { func (l *Lexer) shiftRawText() []byte {
if l.rawTag == Plaintext { if l.rawTag == Plaintext {
for { for {
if l.r.Peek(0) == 0 { if l.r.Peek(0) == 0 && l.r.Err() != nil {
return l.r.Shift() return l.r.Shift()
} }
l.r.Move(1) l.r.Move(1)
@ -237,15 +238,16 @@ func (l *Lexer) shiftRawText() []byte {
inScript = false inScript = false
} }
} }
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
return l.r.Shift() return l.r.Shift()
} else {
l.r.Move(1)
} }
l.r.Move(1)
} }
} else { } else {
l.r.Move(1) l.r.Move(1)
} }
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
return l.r.Shift() return l.r.Shift()
} else { } else {
l.r.Move(1) l.r.Move(1)
@ -258,7 +260,7 @@ func (l *Lexer) readMarkup() (TokenType, []byte) {
if l.at('-', '-') { if l.at('-', '-') {
l.r.Move(2) l.r.Move(2)
for { for {
if l.r.Peek(0) == 0 { if l.r.Peek(0) == 0 && l.r.Err() != nil {
return CommentToken, l.r.Shift() return CommentToken, l.r.Shift()
} else if l.at('-', '-', '>') { } else if l.at('-', '-', '>') {
l.text = l.r.Lexeme()[4:] l.text = l.r.Lexeme()[4:]
@ -274,7 +276,7 @@ func (l *Lexer) readMarkup() (TokenType, []byte) {
} else if l.at('[', 'C', 'D', 'A', 'T', 'A', '[') { } else if l.at('[', 'C', 'D', 'A', 'T', 'A', '[') {
l.r.Move(7) l.r.Move(7)
for { for {
if l.r.Peek(0) == 0 { if l.r.Peek(0) == 0 && l.r.Err() != nil {
return TextToken, l.r.Shift() return TextToken, l.r.Shift()
} else if l.at(']', ']', '>') { } else if l.at(']', ']', '>') {
l.r.Move(3) l.r.Move(3)
@ -289,7 +291,7 @@ func (l *Lexer) readMarkup() (TokenType, []byte) {
l.r.Move(1) l.r.Move(1)
} }
for { for {
if c := l.r.Peek(0); c == '>' || c == 0 { if c := l.r.Peek(0); c == '>' || c == 0 && l.r.Err() != nil {
l.text = l.r.Lexeme()[9:] l.text = l.r.Lexeme()[9:]
if c == '>' { if c == '>' {
l.r.Move(1) l.r.Move(1)
@ -310,7 +312,7 @@ func (l *Lexer) shiftBogusComment() []byte {
l.text = l.r.Lexeme()[2:] l.text = l.r.Lexeme()[2:]
l.r.Move(1) l.r.Move(1)
return l.r.Shift() return l.r.Shift()
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
l.text = l.r.Lexeme()[2:] l.text = l.r.Lexeme()[2:]
return l.r.Shift() return l.r.Shift()
} }
@ -320,19 +322,25 @@ func (l *Lexer) shiftBogusComment() []byte {
func (l *Lexer) shiftStartTag() (TokenType, []byte) { func (l *Lexer) shiftStartTag() (TokenType, []byte) {
for { for {
if c := l.r.Peek(0); c == ' ' || c == '>' || c == '/' && l.r.Peek(1) == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 { if c := l.r.Peek(0); c == ' ' || c == '>' || c == '/' && l.r.Peek(1) == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 && l.r.Err() != nil {
break break
} }
l.r.Move(1) l.r.Move(1)
} }
l.text = parse.ToLower(l.r.Lexeme()[1:]) l.text = parse.ToLower(l.r.Lexeme()[1:])
if h := ToHash(l.text); h == Textarea || h == Title || h == Style || h == Xmp || h == Iframe || h == Script || h == Plaintext || h == Svg || h == Math { if h := ToHash(l.text); h == Textarea || h == Title || h == Style || h == Xmp || h == Iframe || h == Script || h == Plaintext || h == Svg || h == Math {
if h == Svg { if h == Svg || h == Math {
data := l.shiftXml(h)
if l.err != nil {
return ErrorToken, nil
}
l.inTag = false l.inTag = false
return SvgToken, l.shiftXml(h) if h == Svg {
} else if h == Math { return SvgToken, data
l.inTag = false } else {
return MathToken, l.shiftXml(h) return MathToken, data
}
} }
l.rawTag = h l.rawTag = h
} }
@ -343,7 +351,7 @@ func (l *Lexer) shiftAttribute() []byte {
nameStart := l.r.Pos() nameStart := l.r.Pos()
var c byte var c byte
for { // attribute name state for { // attribute name state
if c = l.r.Peek(0); c == ' ' || c == '=' || c == '>' || c == '/' && l.r.Peek(1) == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 { if c = l.r.Peek(0); c == ' ' || c == '=' || c == '>' || c == '/' && l.r.Peek(1) == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 && l.r.Err() != nil {
break break
} }
l.r.Move(1) l.r.Move(1)
@ -374,14 +382,14 @@ func (l *Lexer) shiftAttribute() []byte {
if c == delim { if c == delim {
l.r.Move(1) l.r.Move(1)
break break
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
break break
} }
l.r.Move(1) l.r.Move(1)
} }
} else { // attribute value unquoted state } else { // attribute value unquoted state
for { for {
if c := l.r.Peek(0); c == ' ' || c == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 { if c := l.r.Peek(0); c == ' ' || c == '>' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == 0 && l.r.Err() != nil {
break break
} }
l.r.Move(1) l.r.Move(1)
@ -403,7 +411,7 @@ func (l *Lexer) shiftEndTag() []byte {
l.text = l.r.Lexeme()[2:] l.text = l.r.Lexeme()[2:]
l.r.Move(1) l.r.Move(1)
break break
} else if c == 0 { } else if c == 0 && l.r.Err() != nil {
l.text = l.r.Lexeme()[2:] l.text = l.r.Lexeme()[2:]
break break
} }
@ -422,6 +430,8 @@ func (l *Lexer) shiftEndTag() []byte {
return parse.ToLower(l.r.Shift()) return parse.ToLower(l.r.Shift())
} }
// shiftXml parses the content of a svg or math tag according to the XML 1.1 specifications, including the tag itself.
// So far we have already parsed `<svg` or `<math`.
func (l *Lexer) shiftXml(rawTag Hash) []byte { func (l *Lexer) shiftXml(rawTag Hash) []byte {
inQuote := false inQuote := false
for { for {
@ -429,26 +439,26 @@ func (l *Lexer) shiftXml(rawTag Hash) []byte {
if c == '"' { if c == '"' {
inQuote = !inQuote inQuote = !inQuote
l.r.Move(1) l.r.Move(1)
} else if c == '<' && !inQuote { } else if c == '<' && !inQuote && l.r.Peek(1) == '/' {
if l.r.Peek(1) == '/' { mark := l.r.Pos()
mark := l.r.Pos() l.r.Move(2)
l.r.Move(2) for {
for { if c = l.r.Peek(0); !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') {
if c = l.r.Peek(0); !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') {
break
}
l.r.Move(1)
}
if h := ToHash(parse.ToLower(parse.Copy(l.r.Lexeme()[mark+2:]))); h == rawTag { // copy so that ToLower doesn't change the case of the underlying slice
break break
} }
} else {
l.r.Move(1) l.r.Move(1)
} }
if h := ToHash(parse.ToLower(parse.Copy(l.r.Lexeme()[mark+2:]))); h == rawTag { // copy so that ToLower doesn't change the case of the underlying slice
break
}
} else if c == 0 { } else if c == 0 {
if l.r.Err() == nil {
l.err = parse.NewErrorLexer("unexpected null character", l.r)
}
return l.r.Shift() return l.r.Shift()
} else {
l.r.Move(1)
} }
l.r.Move(1)
} }
for { for {
@ -457,7 +467,10 @@ func (l *Lexer) shiftXml(rawTag Hash) []byte {
l.r.Move(1) l.r.Move(1)
break break
} else if c == 0 { } else if c == 0 {
break if l.r.Err() == nil {
l.err = parse.NewErrorLexer("unexpected null character", l.r)
}
return l.r.Shift()
} }
l.r.Move(1) l.r.Move(1)
} }

View file

@ -63,8 +63,22 @@ func TestTokens(t *testing.T) {
{"<script><!--", TTs{StartTagToken, StartTagCloseToken, TextToken}}, {"<script><!--", TTs{StartTagToken, StartTagCloseToken, TextToken}},
{"<script><!--var x='<script></script>';-->", TTs{StartTagToken, StartTagCloseToken, TextToken}}, {"<script><!--var x='<script></script>';-->", TTs{StartTagToken, StartTagCloseToken, TextToken}},
// NULL
{"foo\x00bar", TTs{TextToken}},
{"<\x00foo>", TTs{TextToken}},
{"<foo\x00>", TTs{StartTagToken, StartTagCloseToken}},
{"</\x00bogus>", TTs{CommentToken}},
{"</foo\x00>", TTs{EndTagToken}},
{"<plaintext>\x00</plaintext>", TTs{StartTagToken, StartTagCloseToken, TextToken}},
{"<script>\x00</script>", TTs{StartTagToken, StartTagCloseToken, TextToken, EndTagToken}},
{"<!--\x00-->", TTs{CommentToken}},
{"<![CDATA[\x00]]>", TTs{TextToken}},
{"<!doctype\x00>", TTs{DoctypeToken}},
{"<?bogus\x00>", TTs{CommentToken}},
{"<?bogus\x00>", TTs{CommentToken}},
// go-fuzz // go-fuzz
{"</>", TTs{EndTagToken}}, {"</>", TTs{TextToken}},
} }
for _, tt := range tokenTests { for _, tt := range tokenTests {
t.Run(tt.html, func(t *testing.T) { t.Run(tt.html, func(t *testing.T) {
@ -135,6 +149,11 @@ func TestAttributes(t *testing.T) {
{"<foo x", []string{"x", ""}}, {"<foo x", []string{"x", ""}},
{"<foo x=", []string{"x", ""}}, {"<foo x=", []string{"x", ""}},
{"<foo x='", []string{"x", "'"}}, {"<foo x='", []string{"x", "'"}},
// NULL
{"<foo \x00>", []string{"\x00", ""}},
{"<foo \x00=\x00>", []string{"\x00", "\x00"}},
{"<foo \x00='\x00'>", []string{"\x00", "'\x00'"}},
} }
for _, tt := range attributeTests { for _, tt := range attributeTests {
t.Run(tt.attr, func(t *testing.T) { t.Run(tt.attr, func(t *testing.T) {
@ -164,7 +183,8 @@ func TestErrors(t *testing.T) {
html string html string
col int col int
}{ }{
{"a\x00b", 2}, {"<svg>\x00</svg>", 6},
{"<svg></svg\x00>", 11},
} }
for _, tt := range errorTests { for _, tt := range errorTests {
t.Run(tt.html, func(t *testing.T) { t.Run(tt.html, func(t *testing.T) {
@ -175,7 +195,8 @@ func TestErrors(t *testing.T) {
if tt.col == 0 { if tt.col == 0 {
test.T(t, l.Err(), io.EOF) test.T(t, l.Err(), io.EOF)
} else if perr, ok := l.Err().(*parse.Error); ok { } else if perr, ok := l.Err().(*parse.Error); ok {
test.T(t, perr.Col, tt.col) _, col, _ := perr.Position()
test.T(t, col, tt.col)
} else { } else {
test.Fail(t, "bad error:", l.Err()) test.Fail(t, "bad error:", l.Err())
} }

View file

@ -599,6 +599,8 @@ func (l *Lexer) consumeRegexpToken() bool {
if l.consumeLineTerminator() { if l.consumeLineTerminator() {
l.r.Rewind(mark) l.r.Rewind(mark)
return false return false
} else if l.r.Peek(0) == 0 {
return true
} }
} else if l.consumeLineTerminator() { } else if l.consumeLineTerminator() {
l.r.Rewind(mark) l.r.Rewind(mark)

View file

@ -99,10 +99,10 @@ func NewParser(r io.Reader) *Parser {
// Err returns the error encountered during tokenization, this is often io.EOF but also other errors can be returned. // Err returns the error encountered during tokenization, this is often io.EOF but also other errors can be returned.
func (p *Parser) Err() error { func (p *Parser) Err() error {
if err := p.r.Err(); err != nil { if p.err != nil {
return err return p.err
} }
return p.err return p.r.Err()
} }
// Restore restores the NULL byte at the end of the buffer. // Restore restores the NULL byte at the end of the buffer.

View file

@ -93,7 +93,8 @@ func TestGrammarsError(t *testing.T) {
if tt.col == 0 { if tt.col == 0 {
test.T(t, p.Err(), io.EOF) test.T(t, p.Err(), io.EOF)
} else if perr, ok := p.Err().(*parse.Error); ok { } else if perr, ok := p.Err().(*parse.Error); ok {
test.T(t, perr.Col, tt.col) _, col, _ := perr.Position()
test.T(t, col, tt.col)
} else { } else {
test.Fail(t, "bad error:", p.Err()) test.Fail(t, "bad error:", p.Err())
} }

View file

@ -81,11 +81,10 @@ func NewLexer(r io.Reader) *Lexer {
// Err returns the error encountered during lexing, this is often io.EOF but also other errors can be returned. // Err returns the error encountered during lexing, this is often io.EOF but also other errors can be returned.
func (l *Lexer) Err() error { func (l *Lexer) Err() error {
err := l.r.Err() if l.err != nil {
if err != nil { return l.err
return err
} }
return l.err return l.r.Err()
} }
// Restore restores the NULL byte at the end of the buffer. // Restore restores the NULL byte at the end of the buffer.
@ -107,7 +106,9 @@ func (l *Lexer) Next() (TokenType, []byte) {
break break
} }
if c == 0 { if c == 0 {
l.err = parse.NewErrorLexer("unexpected null character", l.r) if l.r.Err() == nil {
l.err = parse.NewErrorLexer("unexpected null character", l.r)
}
return ErrorToken, nil return ErrorToken, nil
} else if c != '>' && (c != '/' && c != '?' || l.r.Peek(1) != '>') { } else if c != '>' && (c != '/' && c != '?' || l.r.Peek(1) != '>') {
return AttributeToken, l.shiftAttribute() return AttributeToken, l.shiftAttribute()
@ -148,7 +149,7 @@ func (l *Lexer) Next() (TokenType, []byte) {
l.r.Move(7) l.r.Move(7)
return CDATAToken, l.shiftCDATAText() return CDATAToken, l.shiftCDATAText()
} else if l.at('D', 'O', 'C', 'T', 'Y', 'P', 'E') { } else if l.at('D', 'O', 'C', 'T', 'Y', 'P', 'E') {
l.r.Move(8) l.r.Move(7)
return DOCTYPEToken, l.shiftDOCTYPEText() return DOCTYPEToken, l.shiftDOCTYPEText()
} }
l.r.Move(-2) l.r.Move(-2)
@ -164,7 +165,9 @@ func (l *Lexer) Next() (TokenType, []byte) {
if l.r.Pos() > 0 { if l.r.Pos() > 0 {
return TextToken, l.r.Shift() return TextToken, l.r.Shift()
} }
l.err = parse.NewErrorLexer("unexpected null character", l.r) if l.r.Err() == nil {
l.err = parse.NewErrorLexer("unexpected null character", l.r)
}
return ErrorToken, nil return ErrorToken, nil
} }
l.r.Move(1) l.r.Move(1)

View file

@ -155,6 +155,7 @@ func TestErrors(t *testing.T) {
col int col int
}{ }{
{"a\x00b", 2}, {"a\x00b", 2},
{"<a\x00>", 3},
} }
for _, tt := range errorTests { for _, tt := range errorTests {
t.Run(tt.xml, func(t *testing.T) { t.Run(tt.xml, func(t *testing.T) {
@ -165,7 +166,8 @@ func TestErrors(t *testing.T) {
if tt.col == 0 { if tt.col == 0 {
test.T(t, l.Err(), io.EOF) test.T(t, l.Err(), io.EOF)
} else if perr, ok := l.Err().(*parse.Error); ok { } else if perr, ok := l.Err().(*parse.Error); ok {
test.T(t, perr.Col, tt.col) _, col, _ := perr.Position()
test.T(t, col, tt.col)
} else { } else {
test.Fail(t, "bad error:", l.Err()) test.Fail(t, "bad error:", l.Err())
} }

View file

@ -1,8 +1,6 @@
language: go language: go
go: go:
- 1.4
- 1.5
- tip - tip
before_install: before_install:

View file

@ -1,12 +1,27 @@
a golang library that can get client's real public ip address from http request headers # RealIP
[![Build Status](https://travis-ci.org/tomasen/realip.svg?branch=master)](https://travis-ci.org/Tomasen/realip)
[![GoDoc](https://godoc.org/github.com/Tomasen/realip?status.svg)](http://godoc.org/github.com/Tomasen/realip) [![GoDoc](https://godoc.org/github.com/Tomasen/realip?status.svg)](http://godoc.org/github.com/Tomasen/realip)
Go package that can be used to get client's real public IP, which usually useful for logging HTTP server.
* follow the rule of X-FORWARDED-FOR/rfc7239 ### Feature
* follow the rule of X-Real-Ip
* lan/intranet IP address filtered * Follows the rule of X-Real-IP
* Follows the rule of X-Forwarded-For
* Exclude local or private address
## Example
```go
package main
import "github.com/Tomasen/realip"
func (h *Handler) ServeIndexPage(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
clientIP := realip.FromRequest(r)
log.Println("GET / from", clientIP)
}
```
## Developing ## Developing

View file

@ -1,7 +1,7 @@
package realip package realip
import ( import (
"log" "errors"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -10,62 +10,80 @@ import (
var cidrs []*net.IPNet var cidrs []*net.IPNet
func init() { func init() {
lancidrs := []string{ maxCidrBlocks := []string{
"127.0.0.1/8", "10.0.0.0/8", "169.254.0.0/16", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "127.0.0.1/8", // localhost
"10.0.0.0/8", // 24-bit block
"172.16.0.0/12", // 20-bit block
"192.168.0.0/16", // 16-bit block
"169.254.0.0/16", // link local address
"::1/128", // localhost IPv6
"fc00::/7", // unique local address IPv6
"fe80::/10", // link local address IPv6
} }
cidrs = make([]*net.IPNet, len(lancidrs)) cidrs = make([]*net.IPNet, len(maxCidrBlocks))
for i, maxCidrBlock := range maxCidrBlocks {
for i, it := range lancidrs { _, cidr, _ := net.ParseCIDR(maxCidrBlock)
_, cidrnet, err := net.ParseCIDR(it) cidrs[i] = cidr
if err != nil {
log.Fatalf("ParseCIDR error: %v", err) // assuming I did it right above
}
cidrs[i] = cidrnet
} }
} }
func isLocalAddress(addr string) bool { // isLocalAddress works by checking if the address is under private CIDR blocks.
// List of private CIDR blocks can be seen on :
//
// https://en.wikipedia.org/wiki/Private_network
//
// https://en.wikipedia.org/wiki/Link-local_address
func isPrivateAddress(address string) (bool, error) {
ipAddress := net.ParseIP(address)
if ipAddress == nil {
return false, errors.New("address is not valid")
}
for i := range cidrs { for i := range cidrs {
myaddr := net.ParseIP(addr) if cidrs[i].Contains(ipAddress) {
if cidrs[i].Contains(myaddr) { return true, nil
return true
} }
} }
return false return false, nil
} }
// Request.RemoteAddress contains port, which we want to remove i.e.: // FromRequest return client's real public IP address from http request headers.
// "[::1]:58292" => "[::1]" func FromRequest(r *http.Request) string {
func ipAddrFromRemoteAddr(s string) string { // Fetch header value
idx := strings.LastIndex(s, ":") xRealIP := r.Header.Get("X-Real-Ip")
if idx == -1 { xForwardedFor := r.Header.Get("X-Forwarded-For")
return s
// If both empty, return IP from remote address
if xRealIP == "" && xForwardedFor == "" {
var remoteIP string
// If there are colon in remote address, remove the port number
// otherwise, return remote address as is
if strings.ContainsRune(r.RemoteAddr, ':') {
remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
} else {
remoteIP = r.RemoteAddr
}
return remoteIP
} }
return s[:idx]
// Check list of IP in X-Forwarded-For and return the first global address
for _, address := range strings.Split(xForwardedFor, ",") {
address = strings.TrimSpace(address)
isPrivate, err := isPrivateAddress(address)
if !isPrivate && err == nil {
return address
}
}
// If nothing succeed, return X-Real-IP
return xRealIP
} }
// RealIP return client's real public IP address // RealIP is depreciated, use FromRequest instead
// from http request headers.
func RealIP(r *http.Request) string { func RealIP(r *http.Request) string {
hdr := r.Header return FromRequest(r)
hdrRealIP := hdr.Get("X-Real-Ip")
hdrForwardedFor := hdr.Get("X-Forwarded-For")
if len(hdrForwardedFor) == 0 && len(hdrRealIP) == 0 {
return ipAddrFromRemoteAddr(r.RemoteAddr)
}
// X-Forwarded-For is potentially a list of addresses separated with ","
for _, addr := range strings.Split(hdrForwardedFor, ",") {
// return first non-local address
addr = strings.TrimSpace(addr)
if len(addr) > 0 && !isLocalAddress(addr) {
return addr
}
}
return hdrRealIP
} }

View file

@ -2,11 +2,10 @@ package realip
import ( import (
"net/http" "net/http"
"strings"
"testing" "testing"
) )
func TestIsLocalAddr(t *testing.T) { func TestIsPrivateAddr(t *testing.T) {
testData := map[string]bool{ testData := map[string]bool{
"127.0.0.0": true, "127.0.0.0": true,
"10.0.0.0": true, "10.0.0.0": true,
@ -24,7 +23,12 @@ func TestIsLocalAddr(t *testing.T) {
} }
for addr, isLocal := range testData { for addr, isLocal := range testData {
if isLocalAddress(addr) != isLocal { isPrivate, err := isPrivateAddress(addr)
if err != nil {
t.Errorf("fail processing %s: %v", addr, err)
}
if isPrivate != isLocal {
format := "%s should " format := "%s should "
if !isLocal { if !isLocal {
format += "not " format += "not "
@ -36,51 +40,56 @@ func TestIsLocalAddr(t *testing.T) {
} }
} }
func TestIpAddrFromRemoteAddr(t *testing.T) {
testData := map[string]string{
"127.0.0.1:8888": "127.0.0.1",
"ip:port": "ip",
"ip": "ip",
"12:34::0": "12:34:",
}
for remoteAddr, expectedAddr := range testData {
if actualAddr := ipAddrFromRemoteAddr(remoteAddr); actualAddr != expectedAddr {
t.Errorf("ipAddrFromRemoteAddr of %s should be %s but get %s", remoteAddr, expectedAddr, actualAddr)
}
}
}
func TestRealIP(t *testing.T) { func TestRealIP(t *testing.T) {
newRequest := func(remoteAddr, hdrRealIP, hdrForwardedFor string) *http.Request { // Create type and function for testing
type testIP struct {
name string
request *http.Request
expected string
}
newRequest := func(remoteAddr, xRealIP string, xForwardedFor ...string) *http.Request {
h := http.Header{} h := http.Header{}
h["X-Real-Ip"] = []string{hdrRealIP} h.Set("X-Real-IP", xRealIP)
h["X-Forwarded-For"] = []string{hdrForwardedFor} for _, address := range xForwardedFor {
h.Set("X-Forwarded-For", address)
}
return &http.Request{ return &http.Request{
RemoteAddr: remoteAddr, RemoteAddr: remoteAddr,
Header: h, Header: h,
} }
} }
remoteAddr := "144.12.54.87" // Create test data
anotherRemoteAddr := "119.14.55.11" publicAddr1 := "144.12.54.87"
publicAddr2 := "119.14.55.11"
localAddr := "127.0.0.0" localAddr := "127.0.0.0"
testData := []struct { testData := []testIP{
expected string {
request *http.Request name: "No header",
}{ request: newRequest(publicAddr1, ""),
{remoteAddr, newRequest(remoteAddr, "", "")}, // no header expected: publicAddr1,
{remoteAddr, newRequest("", "", remoteAddr)}, // X-Forwarded-For: remoteAddr }, {
{remoteAddr, newRequest("", remoteAddr, "")}, // X-RealIP: remoteAddr name: "Has X-Forwarded-For",
request: newRequest("", "", publicAddr1),
// X-Forwarded-For: localAddr, remoteAddr, anotherRemoteAddr expected: publicAddr1,
{remoteAddr, newRequest("", "", strings.Join([]string{localAddr, remoteAddr, anotherRemoteAddr}, ", "))}, }, {
name: "Has multiple X-Forwarded-For",
request: newRequest("", "", localAddr, publicAddr1, publicAddr2),
expected: publicAddr2,
}, {
name: "Has X-Real-IP",
request: newRequest("", publicAddr1),
expected: publicAddr1,
},
} }
// Run test
for _, v := range testData { for _, v := range testData {
if actual := RealIP(v.request); v.expected != actual { if actual := FromRequest(v.request); v.expected != actual {
t.Errorf("expected %s but get %s", v.expected, actual) t.Errorf("%s: expected %s but get %s", v.name, v.expected, actual)
} }
} }
} }

View file

@ -946,7 +946,7 @@ func TestNonce_add(t *testing.T) {
c.addNonce(http.Header{"Replay-Nonce": {}}) c.addNonce(http.Header{"Replay-Nonce": {}})
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
nonces := map[string]struct{}{"nonce": struct{}{}} nonces := map[string]struct{}{"nonce": {}}
if !reflect.DeepEqual(c.nonces, nonces) { if !reflect.DeepEqual(c.nonces, nonces) {
t.Errorf("c.nonces = %q; want %q", c.nonces, nonces) t.Errorf("c.nonces = %q; want %q", c.nonces, nonces)
} }

View file

@ -241,11 +241,11 @@ func (p *hashed) Hash() []byte {
n = 3 n = 3
} }
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost))) copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost)))
n += 2 n += 2
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], p.salt) copy(arr[n:], p.salt)
n += encodedSaltSize n += encodedSaltSize
copy(arr[n:], p.hash) copy(arr[n:], p.hash)

View file

@ -49,8 +49,8 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) {
return k, new(G1).ScalarBaseMult(k), nil return k, new(G1).ScalarBaseMult(k), nil
} }
func (g *G1) String() string { func (e *G1) String() string {
return "bn256.G1" + g.p.String() return "bn256.G1" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and
@ -92,11 +92,11 @@ func (e *G1) Neg(a *G1) *G1 {
} }
// Marshal converts n to a byte slice. // Marshal converts n to a byte slice.
func (n *G1) Marshal() []byte { func (e *G1) Marshal() []byte {
n.p.MakeAffine(nil) e.p.MakeAffine(nil)
xBytes := new(big.Int).Mod(n.p.x, p).Bytes() xBytes := new(big.Int).Mod(e.p.x, p).Bytes()
yBytes := new(big.Int).Mod(n.p.y, p).Bytes() yBytes := new(big.Int).Mod(e.p.y, p).Bytes()
// Each value is a 256-bit number. // Each value is a 256-bit number.
const numBytes = 256 / 8 const numBytes = 256 / 8
@ -166,8 +166,8 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) {
return k, new(G2).ScalarBaseMult(k), nil return k, new(G2).ScalarBaseMult(k), nil
} }
func (g *G2) String() string { func (e *G2) String() string {
return "bn256.G2" + g.p.String() return "bn256.G2" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and

View file

@ -47,7 +47,7 @@ func Sum(m []byte, key *[KeySize]byte) *[Size]byte {
// Verify checks that digest is a valid authenticator of message m under the // Verify checks that digest is a valid authenticator of message m under the
// given secret key. Verify does not leak timing information. // given secret key. Verify does not leak timing information.
func Verify(digest []byte, m []byte, key *[32]byte) bool { func Verify(digest []byte, m []byte, key *[KeySize]byte) bool {
if len(digest) != Size { if len(digest) != Size {
return false return false
} }

View file

@ -760,7 +760,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response,
} }
if template.Certificate != nil { if template.Certificate != nil {
response.Certificates = []asn1.RawValue{ response.Certificates = []asn1.RawValue{
asn1.RawValue{FullBytes: template.Certificate.Raw}, {FullBytes: template.Certificate.Raw},
} }
} }
responseDER, err := asn1.Marshal(response) responseDER, err := asn1.Marshal(response)

View file

@ -218,7 +218,7 @@ func TestOCSPResponse(t *testing.T) {
extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex) extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex)
extensions := []pkix.Extension{ extensions := []pkix.Extension{
pkix.Extension{ {
Id: ocspExtensionOID, Id: ocspExtensionOID,
Critical: false, Critical: false,
Value: extensionBytes, Value: extensionBytes,

View file

@ -325,9 +325,8 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok { if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p) packets.Unread(p)
return nil, errors.StructuralError("first packet was not a public/private key") return nil, errors.StructuralError("first packet was not a public/private key")
} else {
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
if !e.PrimaryKey.PubKeyAlgo.CanSign() { if !e.PrimaryKey.PubKeyAlgo.CanSign() {

View file

@ -122,7 +122,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 40 { for j <= 40 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@ -151,7 +150,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 60 { for j <= 60 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@ -244,7 +242,6 @@ func (c *rc2Cipher) Decrypt(dst, src []byte) {
r0 = r0 - c.k[r3&63] r0 = r0 - c.k[r3&63]
for j >= 0 { for j >= 0 {
// unmix r3 // unmix r3
r3 = rotl16(r3, 16-5) r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)

View file

@ -11,7 +11,6 @@ import (
) )
func TestEncryptDecrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) {
// TODO(dgryski): add the rest of the test vectors from the RFC // TODO(dgryski): add the rest of the test vectors from the RFC
var tests = []struct { var tests = []struct {
key string key string

View file

@ -202,7 +202,7 @@ func TestSqueezing(t *testing.T) {
d1 := newShakeHash() d1 := newShakeHash()
d1.Write([]byte(testString)) d1.Write([]byte(testString))
var multiple []byte var multiple []byte
for _ = range ref { for range ref {
one := make([]byte, 1) one := make([]byte, 1)
d1.Read(one) d1.Read(one)
multiple = append(multiple, one...) multiple = append(multiple, one...)

View file

@ -98,7 +98,7 @@ const (
agentAddIdentity = 17 agentAddIdentity = 17
agentRemoveIdentity = 18 agentRemoveIdentity = 18
agentRemoveAllIdentities = 19 agentRemoveAllIdentities = 19
agentAddIdConstrained = 25 agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent // 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20 agentAddSmartcardKey = 20
@ -515,7 +515,7 @@ func (c *client) insertKey(s interface{}, comment string, constraints []byte) er
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
resp, err := c.call(req) resp, err := c.call(req)
@ -577,11 +577,11 @@ func (c *client) Add(key AddedKey) error {
constraints = append(constraints, agentConstrainConfirm) constraints = append(constraints, agentConstrainConfirm)
} }
if cert := key.Certificate; cert == nil { cert := key.Certificate
if cert == nil {
return c.insertKey(key.PrivateKey, key.Comment, constraints) return c.insertKey(key.PrivateKey, key.Comment, constraints)
} else {
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
@ -633,7 +633,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
signer, err := ssh.NewSignerFromKey(s) signer, err := ssh.NewSignerFromKey(s)

View file

@ -148,7 +148,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
} }
return rep, nil return rep, nil
case agentAddIdConstrained, agentAddIdentity: case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data) return nil, s.insertIdentity(data)
} }

View file

@ -343,7 +343,7 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
} }
for opt, _ := range cert.CriticalOptions { for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by // sourceAddressCriticalOption will be enforced by
// serverAuthenticate // serverAuthenticate
if opt == sourceAddressCriticalOption { if opt == sourceAddressCriticalOption {

View file

@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates // writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu. // sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error { func (ch *channel) writePacket(packet []byte) error {
c.writeMu.Lock() ch.writeMu.Lock()
if c.sentClose { if ch.sentClose {
c.writeMu.Unlock() ch.writeMu.Unlock()
return io.EOF return io.EOF
} }
c.sentClose = (packet[0] == msgChannelClose) ch.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet) err := ch.mux.conn.writePacket(packet)
c.writeMu.Unlock() ch.writeMu.Unlock()
return err return err
} }
func (c *channel) sendMessage(msg interface{}) error { func (ch *channel) sendMessage(msg interface{}) error {
if debugMux { if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
} }
p := Marshal(msg) p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId) binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return c.writePacket(p) return ch.writePacket(p)
} }
// WriteExtended writes data to a specific extended stream. These streams are // WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr. // used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF { if ch.sentEOF {
return 0, io.EOF return 0, io.EOF
} }
// 1 byte message type, 4 bytes remoteId, 4 bytes data length // 1 byte message type, 4 bytes remoteId, 4 bytes data length
@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData opCode = msgChannelExtendedData
} }
c.writeMu.Lock() ch.writeMu.Lock()
packet := c.packetPool[extendedCode] packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so // We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be // WriteExtended calls from different goroutines will be
// flagged as errors by the race detector. // flagged as errors by the race detector.
c.writeMu.Unlock() ch.writeMu.Unlock()
for len(data) > 0 { for len(data) > 0 {
space := min(c.maxRemotePayload, len(data)) space := min(ch.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil { if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err return n, err
} }
if want := headerLength + space; uint32(cap(packet)) < want { if want := headerLength + space; uint32(cap(packet)) < want {
@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space] todo := data[:space]
packet[0] = opCode packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId) binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 { if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
} }
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo) copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil { if err = ch.writePacket(packet); err != nil {
return n, err return n, err
} }
@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):] data = data[len(todo):]
} }
c.writeMu.Lock() ch.writeMu.Lock()
c.packetPool[extendedCode] = packet ch.packetPool[extendedCode] = packet
c.writeMu.Unlock() ch.writeMu.Unlock()
return n, err return n, err
} }
func (c *channel) handleData(packet []byte) error { func (ch *channel) handleData(packet []byte) error {
headerLen := 9 headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData { if isExtendedData {
@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 { if length == 0 {
return nil return nil
} }
if length > c.maxIncomingPayload { if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect? // TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size") return errors.New("ssh: incoming packet exceeds maximum payload size")
} }
@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length") return errors.New("ssh: wrong packet length")
} }
c.windowMu.Lock() ch.windowMu.Lock()
if c.myWindow < length { if ch.myWindow < length {
c.windowMu.Unlock() ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason? // TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much") return errors.New("ssh: remote side wrote too much")
} }
c.myWindow -= length ch.myWindow -= length
c.windowMu.Unlock() ch.windowMu.Unlock()
if extended == 1 { if extended == 1 {
c.extPending.write(data) ch.extPending.write(data)
} else if extended > 0 { } else if extended > 0 {
// discard other extended data. // discard other extended data.
} else { } else {
c.pending.write(data) ch.pending.write(data)
} }
return nil return nil
} }
@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is // responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the // received on a channel to check that such a message is reasonable for the
// given channel. // given channel.
func (c *channel) responseMessageReceived() error { func (ch *channel) responseMessageReceived() error {
if c.direction == channelInbound { if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel") return errors.New("ssh: channel response message received on inbound channel")
} }
if c.decided { if ch.decided {
return errors.New("ssh: duplicate response received for channel") return errors.New("ssh: duplicate response received for channel")
} }
c.decided = true ch.decided = true
return nil return nil
} }
func (c *channel) handlePacket(packet []byte) error { func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] { switch packet[0] {
case msgChannelData, msgChannelExtendedData: case msgChannelData, msgChannelExtendedData:
return c.handleData(packet) return ch.handleData(packet)
case msgChannelClose: case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
c.mux.chanList.remove(c.localId) ch.mux.chanList.remove(ch.localId)
c.close() ch.close()
return nil return nil
case msgChannelEOF: case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but // RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time. // it is logical to signal EOF at the same time.
c.extPending.eof() ch.extPending.eof()
c.pending.eof() ch.pending.eof()
return nil return nil
} }
@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) { switch msg := decoded.(type) {
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
c.mux.chanList.remove(msg.PeersId) ch.mux.chanList.remove(msg.PeersID)
c.msg <- msg ch.msg <- msg
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
} }
c.remoteId = msg.MyId ch.remoteId = msg.MyID
c.maxRemotePayload = msg.MaxPacketSize ch.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow) ch.remoteWin.add(msg.MyWindow)
c.msg <- msg ch.msg <- msg
case *windowAdjustMsg: case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) { if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
} }
case *channelRequestMsg: case *channelRequestMsg:
@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request, Type: msg.Request,
WantReply: msg.WantReply, WantReply: msg.WantReply,
Payload: msg.RequestSpecificData, Payload: msg.RequestSpecificData,
ch: c, ch: ch,
} }
c.incomingRequests <- &req ch.incomingRequests <- &req
default: default:
c.msg <- msg ch.msg <- msg
} }
return nil return nil
} }
@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code) return e.ch.ReadExtended(data, e.code)
} }
func (c *channel) Accept() (Channel, <-chan *Request, error) { func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided { if ch.decided {
return nil, nil, errDecidedAlready return nil, nil, errDecidedAlready
} }
c.maxIncomingPayload = channelMaxPacket ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{ confirm := channelOpenConfirmMsg{
PeersId: c.remoteId, PeersID: ch.remoteId,
MyId: c.localId, MyID: ch.localId,
MyWindow: c.myWindow, MyWindow: ch.myWindow,
MaxPacketSize: c.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
} }
c.decided = true ch.decided = true
if err := c.sendMessage(confirm); err != nil { if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err return nil, nil, err
} }
return c, c.incomingRequests, nil return ch, ch.incomingRequests, nil
} }
func (ch *channel) Reject(reason RejectionReason, message string) error { func (ch *channel) Reject(reason RejectionReason, message string) error {
@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready return errDecidedAlready
} }
reject := channelOpenFailureMsg{ reject := channelOpenFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Reason: reason, Reason: reason,
Message: message, Message: message,
Language: "en", Language: "en",
@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
} }
ch.sentEOF = true ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{ return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
func (ch *channel) Close() error { func (ch *channel) Close() error {
@ -550,7 +550,7 @@ func (ch *channel) Close() error {
} }
return ch.sendMessage(channelCloseMsg{ return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
// Extended returns an io.ReadWriter that sends and receives data on the given, // Extended returns an io.ReadWriter that sends and receives data on the given,
@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
} }
msg := channelRequestMsg{ msg := channelRequestMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Request: name, Request: name,
WantReply: wantReply, WantReply: wantReply,
RequestSpecificData: payload, RequestSpecificData: payload,
@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{} var msg interface{}
if !ok { if !ok {
msg = channelRequestFailureMsg{ msg = channelRequestFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} else { } else {
msg = channelRequestSuccessMsg{ msg = channelRequestSuccessMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} }
return ch.sendMessage(msg) return ch.sendMessage(msg)

View file

@ -372,7 +372,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
} }
length := binary.BigEndian.Uint32(c.prefix[:]) length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket { if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.") return nil, errors.New("ssh: max packet length exceeded")
} }
if cap(c.buf) < int(length+gcmTagSize) { if cap(c.buf) < int(length+gcmTagSize) {
@ -548,11 +548,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize] c.packetData = c.packetData[:entirePacketSize]
} }
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err return nil, err
} else {
c.oracleCamouflage -= uint32(n)
} }
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart] remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)

View file

@ -89,7 +89,9 @@ func TestBannerCallback(t *testing.T) {
defer c2.Close() defer c2.Close()
serverConf := &ServerConfig{ serverConf := &ServerConfig{
NoClientAuth: true, PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
BannerCallback: func(conn ConnMetadata) string { BannerCallback: func(conn ConnMetadata) string {
return "Hello World" return "Hello World"
}, },
@ -98,10 +100,15 @@ func TestBannerCallback(t *testing.T) {
go NewServerConn(c1, serverConf) go NewServerConn(c1, serverConf)
var receivedBanner string var receivedBanner string
var bannerCount int
clientConf := ClientConfig{ clientConf := ClientConfig{
Auth: []AuthMethod{
Password("123"),
},
User: "user", User: "user",
HostKeyCallback: InsecureIgnoreHostKey(), HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(message string) error { BannerCallback: func(message string) error {
bannerCount++
receivedBanner = message receivedBanner = message
return nil return nil
}, },
@ -112,6 +119,10 @@ func TestBannerCallback(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if bannerCount != 1 {
t.Errorf("got %d banners; want 1", bannerCount)
}
expected := "Hello World" expected := "Hello World"
if receivedBanner != expected { if receivedBanner != expected {
t.Fatalf("got %s; want %s", receivedBanner, expected) t.Fatalf("got %s; want %s", receivedBanner, expected)

View file

@ -242,7 +242,7 @@ func (c *Config) SetDefaults() {
// buildDataSignedForAuth returns the data that is signed in order to prove // buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7. // possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct { data := struct {
Session []byte Session []byte
Type byte Type byte
@ -253,7 +253,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte Algo []byte
PubKey []byte PubKey []byte
}{ }{
sessionId, sessionID,
msgUserAuthRequest, msgUserAuthRequest,
req.User, req.User,
req.Service, req.Service,

View file

@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err return nil, err
} }
kInt, err := group.diffieHellman(kexDHReply.Y, x) ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey) writeString(h, kexDHReply.HostKey)
writeInt(h, X) writeInt(h, X)
writeInt(h, kexDHReply.Y) writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
} }
Y := new(big.Int).Exp(group.g, y, group.p) Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y) ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X) writeInt(h, kexDHInit.X)
writeInt(h, Y) writeInt(h, Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)
@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey) writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey) writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)

View file

@ -363,7 +363,7 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string { func (k *dsaPublicKey) Type() string {
return "ssh-dss" return "ssh-dss"
} }
@ -481,12 +481,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string { func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID() return "ecdsa-sha2-" + k.nistID()
} }
func (key *ecdsaPublicKey) nistID() string { func (k *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize { switch k.Params().BitSize {
case 256: case 256:
return "nistp256" return "nistp256"
case 384: case 384:
@ -499,7 +499,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string { func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519 return KeyAlgoED25519
} }
@ -518,23 +518,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil return (ed25519PublicKey)(key), w.Rest, nil
} }
func (key ed25519PublicKey) Marshal() []byte { func (k ed25519PublicKey) Marshal() []byte {
w := struct { w := struct {
Name string Name string
KeyBytes []byte KeyBytes []byte
}{ }{
KeyAlgoED25519, KeyAlgoED25519,
[]byte(key), []byte(k),
} }
return Marshal(&w) return Marshal(&w)
} }
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
edKey := (ed25519.PublicKey)(key) edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
} }
@ -595,9 +595,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil return (*ecdsaPublicKey)(key), w.Rest, nil
} }
func (key *ecdsaPublicKey) Marshal() []byte { func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1. // See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by // ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package. // parseECDSACert in the x/crypto/ssh/agent package.
w := struct { w := struct {
@ -605,20 +605,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string ID string
Key []byte Key []byte
}{ }{
key.Type(), k.Type(),
key.nistID(), k.nistID(),
keyBytes, keyBytes,
} }
return Marshal(&w) return Marshal(&w)
} }
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
h := ecHash(key.Curve).New() h := ecHash(k.Curve).New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
@ -635,7 +635,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil return nil
} }
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
@ -758,7 +758,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) { if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
} }
return (*ecdsaPublicKey)(key), nil return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey: case *dsa.PublicKey:

View file

@ -108,8 +108,8 @@ func wildcardMatch(pat []byte, str []byte) bool {
} }
} }
func (l *hostPattern) match(a addr) bool { func (p *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
} }
type keyDBLine struct { type keyDBLine struct {

View file

@ -162,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct { type channelOpenMsg struct {
ChanType string `sshtype:"90"` ChanType string `sshtype:"90"`
PeersId uint32 PeersID uint32
PeersWindow uint32 PeersWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@ -173,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets. // Used for debug print outs of packets.
type channelDataMsg struct { type channelDataMsg struct {
PeersId uint32 `sshtype:"94"` PeersID uint32 `sshtype:"94"`
Length uint32 Length uint32
Rest []byte `ssh:"rest"` Rest []byte `ssh:"rest"`
} }
@ -182,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91 const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct { type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"` PeersID uint32 `sshtype:"91"`
MyId uint32 MyID uint32
MyWindow uint32 MyWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@ -193,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92 const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct { type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"` PeersID uint32 `sshtype:"92"`
Reason RejectionReason Reason RejectionReason
Message string Message string
Language string Language string
@ -202,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98 const msgChannelRequest = 98
type channelRequestMsg struct { type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"` PeersID uint32 `sshtype:"98"`
Request string Request string
WantReply bool WantReply bool
RequestSpecificData []byte `ssh:"rest"` RequestSpecificData []byte `ssh:"rest"`
@ -212,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99 const msgChannelSuccess = 99
type channelRequestSuccessMsg struct { type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"` PeersID uint32 `sshtype:"99"`
} }
// See RFC 4254, section 5.4. // See RFC 4254, section 5.4.
const msgChannelFailure = 100 const msgChannelFailure = 100
type channelRequestFailureMsg struct { type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"` PeersID uint32 `sshtype:"100"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelClose = 97 const msgChannelClose = 97
type channelCloseMsg struct { type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"` PeersID uint32 `sshtype:"97"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelEOF = 96 const msgChannelEOF = 96
type channelEOFMsg struct { type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"` PeersID uint32 `sshtype:"96"`
} }
// See RFC 4254, section 4 // See RFC 4254, section 4
@ -263,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93 const msgChannelWindowAdjust = 93
type windowAdjustMsg struct { type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"` PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32 AdditionalBytes uint32
} }

View file

@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{ failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId, PeersID: msg.PeersID,
Reason: ConnectionFailed, Reason: ConnectionFailed,
Message: "invalid request", Message: "invalid request",
Language: "en_US.UTF-8", Language: "en_US.UTF-8",
@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
} }
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow) c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c m.incomingChannels <- c
@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow, PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra, TypeSpecificData: extra,
PeersId: ch.localId, PeersID: ch.localId,
} }
if err := m.sendMessage(open); err != nil { if err := m.sendMessage(open); err != nil {
return nil, err return nil, err

View file

@ -316,6 +316,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0 authFailures := 0
var authErrs []error var authErrs []error
var displayedBanner bool
userAuthLoop: userAuthLoop:
for { for {
@ -348,7 +349,8 @@ userAuthLoop:
s.user = userAuthReq.User s.user = userAuthReq.User
if authFailures == 0 && config.BannerCallback != nil { if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s) msg := config.BannerCallback(s)
if msg != "" { if msg != "" {
bannerMsg := &userAuthBannerMsg{ bannerMsg := &userAuthBannerMsg{

View file

@ -406,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close() s.stdinPipeWriter.Close()
} }
var copyError error var copyError error
for _ = range s.copyFuncs { for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil { if err := <-s.errors; err != nil && copyError == nil {
copyError = err copyError = err
} }

View file

@ -617,7 +617,7 @@ func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
if _, err = w.Write(crlf); err != nil { if _, err = w.Write(crlf); err != nil {
return n, err return n, err
} }
n += 1 n++
buf = buf[1:] buf = buf[1:]
} }
} }

View file

@ -8,7 +8,6 @@ package test
import ( import (
"testing" "testing"
) )
func TestBannerCallbackAgainstOpenSSH(t *testing.T) { func TestBannerCallbackAgainstOpenSSH(t *testing.T) {

View file

@ -2,6 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package contains integration tests for the // Package test contains integration tests for the
// golang.org/x/crypto/ssh package. // golang.org/x/crypto/ssh package.
package test // import "golang.org/x/crypto/ssh/test" package test // import "golang.org/x/crypto/ssh/test"

View file

@ -25,7 +25,7 @@ import (
"golang.org/x/crypto/ssh/testdata" "golang.org/x/crypto/ssh/testdata"
) )
const sshd_config = ` const sshdConfig = `
Protocol 2 Protocol 2
Banner {{.Dir}}/banner Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa HostKey {{.Dir}}/id_rsa
@ -51,7 +51,7 @@ HostbasedAuthentication no
PubkeyAcceptedKeyTypes=* PubkeyAcceptedKeyTypes=*
` `
var configTmpl = template.Must(template.New("").Parse(sshd_config)) var configTmpl = template.Must(template.New("").Parse(sshdConfig))
type server struct { type server struct {
t *testing.T t *testing.T
@ -271,7 +271,7 @@ func newServer(t *testing.T) *server {
} }
var authkeys bytes.Buffer var authkeys bytes.Buffer
for k, _ := range testdata.PEMBytes { for k := range testdata.PEMBytes {
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
} }
writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())

View file

@ -76,17 +76,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err return err
} else {
t.reader.pendingKeyChange <- ciph
} }
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err return err
} else {
t.writer.pendingKeyChange <- ciph
} }
t.writer.pendingKeyChange <- ciph
return nil return nil
} }
@ -139,7 +139,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
default: default:
return nil, errors.New("ssh: got bogus newkeys message.") return nil, errors.New("ssh: got bogus newkeys message")
} }
case msgDisconnect: case msgDisconnect:

View file

@ -5,7 +5,6 @@
// Package tea implements the TEA algorithm, as defined in Needham and // Package tea implements the TEA algorithm, as defined in Needham and
// Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See // Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See
// http://www.cix.co.uk/~klockstone/tea.pdf for details. // http://www.cix.co.uk/~klockstone/tea.pdf for details.
package tea package tea
import ( import (

View file

@ -69,7 +69,7 @@ func initCipher(c *Cipher, key []byte) {
// Precalculate the table // Precalculate the table
const delta = 0x9E3779B9 const delta = 0x9E3779B9
var sum uint32 = 0 var sum uint32
// Two rounds of XTEA applied per loop // Two rounds of XTEA applied per loop
for i := 0; i < numRounds; { for i := 0; i < numRounds; {

13
vendor/golang.org/x/net/http/httpproxy/go19_test.go generated vendored Normal file
View file

@ -0,0 +1,13 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.9
package httpproxy_test
import "testing"
func init() {
setHelper = func(t *testing.T) { t.Helper() }
}

View file

@ -16,6 +16,9 @@ import (
"golang.org/x/net/http/httpproxy" "golang.org/x/net/http/httpproxy"
) )
// setHelper calls t.Helper() for Go 1.9+ (see go19_test.go) and does nothing otherwise.
var setHelper = func(t *testing.T) {}
type proxyForURLTest struct { type proxyForURLTest struct {
cfg httpproxy.Config cfg httpproxy.Config
req string // URL to fetch; blank means "http://example.com" req string // URL to fetch; blank means "http://example.com"
@ -166,7 +169,7 @@ var proxyForURLTests = []proxyForURLTest{{
}} }}
func testProxyForURL(t *testing.T, tt proxyForURLTest) { func testProxyForURL(t *testing.T, tt proxyForURLTest) {
t.Helper() setHelper(t)
reqURLStr := tt.req reqURLStr := tt.req
if reqURLStr == "" { if reqURLStr == "" {
reqURLStr = "http://example.com" reqURLStr = "http://example.com"

View file

@ -46,7 +46,6 @@ func TestServerGracefulShutdown(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"x-foo", "bar"}, {"x-foo", "bar"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {

View file

@ -2322,7 +2322,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
clen = strconv.Itoa(len(p)) clen = strconv.Itoa(len(p))
} }
_, hasContentType := rws.snapHeader["Content-Type"] _, hasContentType := rws.snapHeader["Content-Type"]
if !hasContentType && bodyAllowedForStatus(rws.status) { if !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 {
ctype = http.DetectContentType(p) ctype = http.DetectContentType(p)
} }
var date string var date string
@ -2490,7 +2490,26 @@ func (w *responseWriter) Header() http.Header {
return rws.handlerHeader return rws.handlerHeader
} }
// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode.
func checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
// at http://httpwg.org/specs/rfc7231.html#status.codes)
// and we might block under 200 (once we have more mature 1xx support).
// But for now any three digits.
//
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
// no equivalent bogus thing we can realistically send in HTTP/2,
// so we'll consistently panic instead and help people find their bugs
// early. (We can't return an error from WriteHeader even if we wanted to.)
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
func (w *responseWriter) WriteHeader(code int) { func (w *responseWriter) WriteHeader(code int) {
checkWriteHeaderCode(code)
rws := w.rws rws := w.rws
if rws == nil { if rws == nil {
panic("WriteHeader called after Handler finished") panic("WriteHeader called after Handler finished")

View file

@ -1718,7 +1718,6 @@ func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"foo-bar", "some-value"}, {"foo-bar", "some-value"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {
@ -2953,7 +2952,6 @@ func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
wanth := [][2]string{ wanth := [][2]string{
{":status", "200"}, {":status", "200"},
{"ok1", "x"}, {"ok1", "x"},
{"content-type", "text/plain; charset=utf-8"},
{"content-length", "0"}, {"content-length", "0"},
} }
if !reflect.DeepEqual(goth, wanth) { if !reflect.DeepEqual(goth, wanth) {
@ -3266,7 +3264,6 @@ func TestServerNoAutoContentLengthOnHead(t *testing.T) {
headers := st.decodeHeader(h.HeaderBlockFragment()) headers := st.decodeHeader(h.HeaderBlockFragment())
want := [][2]string{ want := [][2]string{
{":status", "200"}, {":status", "200"},
{"content-type", "text/plain; charset=utf-8"},
} }
if !reflect.DeepEqual(headers, want) { if !reflect.DeepEqual(headers, want) {
t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)

View file

@ -811,7 +811,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
cc.wmu.Lock() cc.wmu.Lock()
endStream := !hasBody && !hasTrailers endStream := !hasBody && !hasTrailers
werr := cc.writeHeaders(cs.ID, endStream, hdrs) werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
cc.wmu.Unlock() cc.wmu.Unlock()
traceWroteHeaders(cs.trace) traceWroteHeaders(cs.trace)
cc.mu.Unlock() cc.mu.Unlock()
@ -964,13 +964,12 @@ func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
} }
// requires cc.wmu be held // requires cc.wmu be held
func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error {
first := true // first frame written (HEADERS is first, then CONTINUATION) first := true // first frame written (HEADERS is first, then CONTINUATION)
frameSize := int(cc.maxFrameSize)
for len(hdrs) > 0 && cc.werr == nil { for len(hdrs) > 0 && cc.werr == nil {
chunk := hdrs chunk := hdrs
if len(chunk) > frameSize { if len(chunk) > maxFrameSize {
chunk = chunk[:frameSize] chunk = chunk[:maxFrameSize]
} }
hdrs = hdrs[len(chunk):] hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0 endHeaders := len(hdrs) == 0
@ -1087,13 +1086,17 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
} }
} }
cc.mu.Lock()
maxFrameSize := int(cc.maxFrameSize)
cc.mu.Unlock()
cc.wmu.Lock() cc.wmu.Lock()
defer cc.wmu.Unlock() defer cc.wmu.Unlock()
// Two ways to send END_STREAM: either with trailers, or // Two ways to send END_STREAM: either with trailers, or
// with an empty DATA frame. // with an empty DATA frame.
if len(trls) > 0 { if len(trls) > 0 {
err = cc.writeHeaders(cs.ID, true, trls) err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls)
} else { } else {
err = cc.fr.WriteData(cs.ID, true, nil) err = cc.fr.WriteData(cs.ID, true, nil)
} }
@ -1373,17 +1376,12 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type clientConnReadLoop struct { type clientConnReadLoop struct {
cc *ClientConn cc *ClientConn
activeRes map[uint32]*clientStream // keyed by streamID
closeWhenIdle bool closeWhenIdle bool
} }
// readLoop runs in its own goroutine and reads and dispatches frames. // readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *ClientConn) readLoop() { func (cc *ClientConn) readLoop() {
rl := &clientConnReadLoop{ rl := &clientConnReadLoop{cc: cc}
cc: cc,
activeRes: make(map[uint32]*clientStream),
}
defer rl.cleanup() defer rl.cleanup()
cc.readerErr = rl.run() cc.readerErr = rl.run()
if ce, ok := cc.readerErr.(ConnectionError); ok { if ce, ok := cc.readerErr.(ConnectionError); ok {
@ -1438,10 +1436,8 @@ func (rl *clientConnReadLoop) cleanup() {
} else if err == io.EOF { } else if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
for _, cs := range rl.activeRes {
cs.bufPipe.CloseWithError(err)
}
for _, cs := range cc.streams { for _, cs := range cc.streams {
cs.bufPipe.CloseWithError(err) // no-op if already closed
select { select {
case cs.resc <- resAndError{err: err}: case cs.resc <- resAndError{err: err}:
default: default:
@ -1519,7 +1515,7 @@ func (rl *clientConnReadLoop) run() error {
} }
return err return err
} }
if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { if rl.closeWhenIdle && gotReply && maybeIdle {
cc.closeIfIdle() cc.closeIfIdle()
} }
} }
@ -1527,6 +1523,13 @@ func (rl *clientConnReadLoop) run() error {
func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error { func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
cc := rl.cc cc := rl.cc
cs := cc.streamByID(f.StreamID, false)
if cs == nil {
// We'd get here if we canceled a request while the
// server had its response still in flight. So if this
// was just something we canceled, ignore it.
return nil
}
if f.StreamEnded() { if f.StreamEnded() {
// Issue 20521: If the stream has ended, streamByID() causes // Issue 20521: If the stream has ended, streamByID() causes
// clientStream.done to be closed, which causes the request's bodyWriter // clientStream.done to be closed, which causes the request's bodyWriter
@ -1535,14 +1538,15 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
// Deferring stream closure allows the header processing to occur first. // Deferring stream closure allows the header processing to occur first.
// clientConn.RoundTrip may still receive the bodyWriter error first, but // clientConn.RoundTrip may still receive the bodyWriter error first, but
// the fix for issue 16102 prioritises any response. // the fix for issue 16102 prioritises any response.
defer cc.streamByID(f.StreamID, true) //
} // Issue 22413: If there is no request body, we should close the
cs := cc.streamByID(f.StreamID, false) // stream before writing to cs.resc so that the stream is closed
if cs == nil { // immediately once RoundTrip returns.
// We'd get here if we canceled a request while the if cs.req.Body != nil {
// server had its response still in flight. So if this defer cc.forgetStreamID(f.StreamID)
// was just something we canceled, ignore it. } else {
return nil cc.forgetStreamID(f.StreamID)
}
} }
if !cs.firstByte { if !cs.firstByte {
if cs.trace != nil { if cs.trace != nil {
@ -1567,6 +1571,7 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
} }
// Any other error type is a stream error. // Any other error type is a stream error.
cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err) cs.cc.writeStreamReset(f.StreamID, ErrCodeProtocol, err)
cc.forgetStreamID(cs.ID)
cs.resc <- resAndError{err: err} cs.resc <- resAndError{err: err}
return nil // return nil from process* funcs to keep conn alive return nil // return nil from process* funcs to keep conn alive
} }
@ -1574,9 +1579,6 @@ func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
// (nil, nil) special case. See handleResponse docs. // (nil, nil) special case. See handleResponse docs.
return nil return nil
} }
if res.Body != noBody {
rl.activeRes[cs.ID] = cs
}
cs.resTrailer = &res.Trailer cs.resTrailer = &res.Trailer
cs.resc <- resAndError{res: res} cs.resc <- resAndError{res: res}
return nil return nil
@ -1596,11 +1598,11 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
status := f.PseudoValue("status") status := f.PseudoValue("status")
if status == "" { if status == "" {
return nil, errors.New("missing status pseudo header") return nil, errors.New("malformed response from server: missing status pseudo header")
} }
statusCode, err := strconv.Atoi(status) statusCode, err := strconv.Atoi(status)
if err != nil { if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header") return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header")
} }
if statusCode == 100 { if statusCode == 100 {
@ -1915,7 +1917,6 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
rl.closeWhenIdle = true rl.closeWhenIdle = true
} }
cs.bufPipe.closeWithErrorAndCode(err, code) cs.bufPipe.closeWithErrorAndCode(err, code)
delete(rl.activeRes, cs.ID)
select { select {
case cs.resc <- resAndError{err: err}: case cs.resc <- resAndError{err: err}:
@ -2042,7 +2043,6 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
cs.bufPipe.CloseWithError(err) cs.bufPipe.CloseWithError(err)
cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
} }
delete(rl.activeRes, cs.ID)
return nil return nil
} }

View file

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@ -2291,6 +2292,11 @@ func TestTransportReadHeadResponse(t *testing.T) {
} }
func TestTransportReadHeadResponseWithBody(t *testing.T) { func TestTransportReadHeadResponseWithBody(t *testing.T) {
// This test use not valid response format.
// Discarding logger output to not spam tests output.
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
response := "redirecting to /elsewhere" response := "redirecting to /elsewhere"
ct := newClientTester(t) ct := newClientTester(t)
clientDone := make(chan struct{}) clientDone := make(chan struct{})
@ -3383,6 +3389,11 @@ func TestTransportRetryHasLimit(t *testing.T) {
} }
func TestTransportResponseDataBeforeHeaders(t *testing.T) { func TestTransportResponseDataBeforeHeaders(t *testing.T) {
// This test use not valid response format.
// Discarding logger output to not spam tests output.
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
ct := newClientTester(t) ct := newClientTester(t)
ct.client = func() error { ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite() defer ct.cc.(*net.TCPConn).CloseWrite()
@ -3788,6 +3799,46 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
} }
} }
// Verify transport doesn't crash when receiving bogus response lacking a :status header.
// Issue 22880.
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
ct := newClientTester(t)
ct.client = func() error {
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
_, err := ct.tr.RoundTrip(req)
const substr = "malformed response from server: missing status pseudo header"
if !strings.Contains(fmt.Sprint(err), substr) {
return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
}
return nil
}
ct.server = func() error {
ct.greet()
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
for {
f, err := ct.fr.ReadFrame()
if err != nil {
return err
}
switch f := f.(type) {
case *HeadersFrame:
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
ct.fr.WriteHeaders(HeadersFrameParam{
StreamID: f.StreamID,
EndHeaders: true,
EndStream: false, // we'll send some DATA to try to crash the transport
BlockFragment: buf.Bytes(),
})
ct.fr.WriteData(f.StreamID, true, []byte("payload"))
return nil
}
}
}
ct.run()
}
func BenchmarkClientRequestHeaders(b *testing.B) { func BenchmarkClientRequestHeaders(b *testing.B) {
b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) }) b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) })
b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) }) b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) })

View file

@ -74,6 +74,11 @@ func TestableNetwork(network string) bool {
switch runtime.GOOS { switch runtime.GOOS {
case "android", "darwin", "freebsd", "nacl", "plan9", "windows": case "android", "darwin", "freebsd", "nacl", "plan9", "windows":
return false return false
case "netbsd":
// It passes on amd64 at least. 386 fails (Issue 22927). arm is unknown.
if runtime.GOARCH == "386" {
return false
}
} }
} }
return true return true

View file

@ -110,7 +110,7 @@ func ControlMessageSpace(dataLen int) int {
type ControlMessage []byte type ControlMessage []byte
// Data returns the data field of the control message at the head on // Data returns the data field of the control message at the head on
// w. // m.
func (m ControlMessage) Data(dataLen int) []byte { func (m ControlMessage) Data(dataLen int) []byte {
l := controlHeaderLen() l := controlHeaderLen()
if len(m) < l || len(m) < l+dataLen { if len(m) < l || len(m) < l+dataLen {
@ -119,7 +119,7 @@ func (m ControlMessage) Data(dataLen int) []byte {
return m[l : l+dataLen] return m[l : l+dataLen]
} }
// Next returns the control message at the next on w. // Next returns the control message at the next on m.
// //
// Next works only for standard control messages. // Next works only for standard control messages.
func (m ControlMessage) Next(dataLen int) ControlMessage { func (m ControlMessage) Next(dataLen int) ControlMessage {
@ -131,7 +131,7 @@ func (m ControlMessage) Next(dataLen int) ControlMessage {
} }
// MarshalHeader marshals the header fields of the control message at // MarshalHeader marshals the header fields of the control message at
// the head on w. // the head on m.
func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error { func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error {
if len(m) < controlHeaderLen() { if len(m) < controlHeaderLen() {
return errors.New("short message") return errors.New("short message")
@ -142,7 +142,7 @@ func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error {
} }
// ParseHeader parses and returns the header fields of the control // ParseHeader parses and returns the header fields of the control
// message at the head on w. // message at the head on m.
func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) { func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) {
l := controlHeaderLen() l := controlHeaderLen()
if len(m) < l { if len(m) < l {
@ -152,7 +152,7 @@ func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) {
return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil
} }
// Marshal marshals the control message at the head on w, and returns // Marshal marshals the control message at the head on m, and returns
// the next control message. // the next control message.
func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) { func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) {
l := len(data) l := len(data)
@ -167,7 +167,7 @@ func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, erro
return m.Next(l), nil return m.Next(l), nil
} }
// Parse parses w as a single or multiple control messages. // Parse parses m as a single or multiple control messages.
// //
// Parse works for both standard and compatible messages. // Parse works for both standard and compatible messages.
func (m ControlMessage) Parse() ([]ControlMessage, error) { func (m ControlMessage) Parse() ([]ControlMessage, error) {

View file

@ -0,0 +1,61 @@
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs defs_darwin.go
package socket
const (
sysAF_UNSPEC = 0x0
sysAF_INET = 0x2
sysAF_INET6 = 0x1e
sysSOCK_RAW = 0x3
)
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Pad_cgo_0 [4]byte
Iov *iovec
Iovlen int32
Pad_cgo_1 [4]byte
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
}
type sockaddrInet struct {
Len uint8
Family uint8
Port uint16
Addr [4]byte /* in_addr */
Zero [8]int8
}
type sockaddrInet6 struct {
Len uint8
Family uint8
Port uint16
Flowinfo uint32
Addr [16]byte /* in6_addr */
Scope_id uint32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
sizeofCmsghdr = 0xc
sizeofSockaddrInet = 0x10
sizeofSockaddrInet6 = 0x1c
)

View file

@ -6,6 +6,7 @@ package internal
import ( import (
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -32,6 +33,7 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) {
if got, want := r.FormValue("client_secret"), ""; got != want { if got, want := r.FormValue("client_secret"), ""; got != want {
t.Errorf("client_secret = %q; want empty", got) t.Errorf("client_secret = %q; want empty", got)
} }
io.WriteString(w, "{}") // something non-empty, required to set a Content-Type in Go 1.10
})) }))
defer ts.Close() defer ts.Close()
@ -82,7 +84,9 @@ func TestProviderAuthHeaderWorksDomain(t *testing.T) {
func TestRetrieveTokenWithContexts(t *testing.T) { func TestRetrieveTokenWithContexts(t *testing.T) {
const clientID = "client-id" const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "{}") // something non-empty, required to set a Content-Type in Go 1.10
}))
defer ts.Close() defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{})

16
vendor/golang.org/x/oauth2/mailru/mailru.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mailru provides constants for using OAuth2 to access Mail.Ru.
package mailru // import "golang.org/x/oauth2/mailru"
import (
"golang.org/x/oauth2"
)
// Endpoint is Mail.Ru's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{
AuthURL: "https://o2.mail.ru/login",
TokenURL: "https://o2.mail.ru/token",
}

View file

@ -14,3 +14,18 @@ var LiveConnectEndpoint = oauth2.Endpoint{
AuthURL: "https://login.live.com/oauth20_authorize.srf", AuthURL: "https://login.live.com/oauth20_authorize.srf",
TokenURL: "https://login.live.com/oauth20_token.srf", TokenURL: "https://login.live.com/oauth20_token.srf",
} }
// AzureADEndpoint returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
// If tenant is empty, it uses the tenant called `common`.
//
// For more information see:
// https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
func AzureADEndpoint(tenant string) oauth2.Endpoint {
if tenant == "" {
tenant = "common"
}
return oauth2.Endpoint{
AuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token",
}
}

View file

@ -123,7 +123,7 @@ func (t *Token) expired() bool {
if t.Expiry.IsZero() { if t.Expiry.IsZero() {
return false return false
} }
return t.Expiry.Add(-expiryDelta).Before(time.Now()) return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
} }
// Valid reports whether t is non-nil, has an AccessToken, and is not expired. // Valid reports whether t is non-nil, has an AccessToken, and is not expired.

19
vendor/golang.org/x/oauth2/twitch/twitch.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package twitch provides constants for using OAuth2 to access Twitch.
package twitch // import "golang.org/x/oauth2/twitch"
import (
"golang.org/x/oauth2"
)
// Endpoint is Twitch's OAuth 2.0 endpoint.
//
// For more information see:
// https://dev.twitch.tv/docs/authentication
var Endpoint = oauth2.Endpoint{
AuthURL: "https://api.twitch.tv/kraken/oauth2/authorize",
TokenURL: "https://api.twitch.tv/kraken/oauth2/token",
}

17
vendor/golang.org/x/oauth2/yahoo/yahoo.go generated vendored Normal file
View file

@ -0,0 +1,17 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package yahoo provides constants for using OAuth2 to access Yahoo.
package yahoo // import "golang.org/x/oauth2/yahoo"
import (
"golang.org/x/oauth2"
)
// Endpoint is Yahoo's OAuth 2.0 endpoint.
// See https://developer.yahoo.com/oauth2/guide/
var Endpoint = oauth2.Endpoint{
AuthURL: "https://api.login.yahoo.com/oauth2/request_auth",
TokenURL: "https://api.login.yahoo.com/oauth2/get_token",
}

View file

@ -1 +1,2 @@
_obj/ _obj/
unix.test

View file

@ -8,6 +8,7 @@ package unix_test
import ( import (
"bytes" "bytes"
"go/build"
"net" "net"
"os" "os"
"syscall" "syscall"
@ -35,6 +36,11 @@ func TestSCMCredentials(t *testing.T) {
} }
for _, tt := range socketTypeTests { for _, tt := range socketTypeTests {
if tt.socketType == unix.SOCK_DGRAM && !atLeast1p10() {
t.Log("skipping DGRAM test on pre-1.10")
continue
}
fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0) fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0)
if err != nil { if err != nil {
t.Fatalf("Socketpair: %v", err) t.Fatalf("Socketpair: %v", err)
@ -134,3 +140,13 @@ func TestSCMCredentials(t *testing.T) {
} }
} }
} }
// atLeast1p10 reports whether we are running on Go 1.10 or later.
func atLeast1p10() bool {
for _, ver := range build.Default.ReleaseTags {
if ver == "go1.10" {
return true
}
}
return false
}

View file

@ -20,11 +20,9 @@ func TestDevices(t *testing.T) {
minor uint32 minor uint32
}{ }{
// well known major/minor numbers according to /dev/MAKEDEV on // well known major/minor numbers according to /dev/MAKEDEV on
// NetBSD 7.0 // NetBSD 8.0
{"/dev/null", 2, 2}, {"/dev/null", 2, 2},
{"/dev/zero", 2, 12}, {"/dev/zero", 2, 12},
{"/dev/ttyp0", 5, 0},
{"/dev/ttyp1", 5, 1},
{"/dev/random", 46, 0}, {"/dev/random", 46, 0},
{"/dev/urandom", 46, 1}, {"/dev/urandom", 46, 1},
} }

View file

@ -352,6 +352,18 @@ func GetsockoptICMPv6Filter(fd, level, opt int) (*ICMPv6Filter, error) {
return &value, err return &value, err
} }
// GetsockoptString returns the string value of the socket option opt for the
// socket associated with fd at the given socket level.
func GetsockoptString(fd, level, opt int) (string, error) {
buf := make([]byte, 256)
vallen := _Socklen(len(buf))
err := getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen)
if err != nil {
return "", err
}
return string(buf[:vallen-1]), nil
}
//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) //sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error)
//sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error) //sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error)
//sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error) //sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error)

View file

@ -110,6 +110,23 @@ func Accept4(fd, flags int) (nfd int, sa Sockaddr, err error) {
return return
} }
const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error) = SYS___GETCWD
func Getwd() (string, error) {
var buf [PathMax]byte
_, err := Getcwd(buf[0:])
if err != nil {
return "", err
}
n := clen(buf[:])
if n < 1 {
return "", EINVAL
}
return string(buf[:n]), nil
}
func Getfsstat(buf []Statfs_t, flags int) (n int, err error) { func Getfsstat(buf []Statfs_t, flags int) (n int, err error) {
var _p0 unsafe.Pointer var _p0 unsafe.Pointer
var bufsize uintptr var bufsize uintptr
@ -169,6 +186,69 @@ func IoctlGetTermios(fd int, req uint) (*Termios, error) {
return &value, err return &value, err
} }
func sysctlUname(mib []_C_int, old *byte, oldlen *uintptr) error {
err := sysctl(mib, old, oldlen, nil, 0)
if err != nil {
// Utsname members on Dragonfly are only 32 bytes and
// the syscall returns ENOMEM in case the actual value
// is longer.
if err == ENOMEM {
err = nil
}
}
return err
}
func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname)
if err := sysctlUname(mib, &uname.Sysname[0], &n); err != nil {
return err
}
uname.Sysname[unsafe.Sizeof(uname.Sysname)-1] = 0
mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename)
if err := sysctlUname(mib, &uname.Nodename[0], &n); err != nil {
return err
}
uname.Nodename[unsafe.Sizeof(uname.Nodename)-1] = 0
mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release)
if err := sysctlUname(mib, &uname.Release[0], &n); err != nil {
return err
}
uname.Release[unsafe.Sizeof(uname.Release)-1] = 0
mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version)
if err := sysctlUname(mib, &uname.Version[0], &n); err != nil {
return err
}
// The version might have newlines or tabs in it, convert them to
// spaces.
for i, b := range uname.Version {
if b == '\n' || b == '\t' {
if i == len(uname.Version)-1 {
uname.Version[i] = 0
} else {
uname.Version[i] = ' '
}
}
}
mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine)
if err := sysctlUname(mib, &uname.Machine[0], &n); err != nil {
return err
}
uname.Machine[unsafe.Sizeof(uname.Machine)-1] = 0
return nil
}
/* /*
* Exposed directly * Exposed directly
*/ */

View file

@ -105,6 +105,23 @@ func Accept4(fd, flags int) (nfd int, sa Sockaddr, err error) {
return return
} }
const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error) = SYS___GETCWD
func Getwd() (string, error) {
var buf [PathMax]byte
_, err := Getcwd(buf[0:])
if err != nil {
return "", err
}
n := clen(buf[:])
if n < 1 {
return "", EINVAL
}
return string(buf[:n]), nil
}
func Getfsstat(buf []Statfs_t, flags int) (n int, err error) { func Getfsstat(buf []Statfs_t, flags int) (n int, err error) {
var _p0 unsafe.Pointer var _p0 unsafe.Pointer
var bufsize uintptr var bufsize uintptr
@ -396,6 +413,52 @@ func IoctlGetTermios(fd int, req uint) (*Termios, error) {
return &value, err return &value, err
} }
func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname)
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename)
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release)
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version)
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil {
return err
}
// The version might have newlines or tabs in it, convert them to
// spaces.
for i, b := range uname.Version {
if b == '\n' || b == '\t' {
if i == len(uname.Version)-1 {
uname.Version[i] = 0
} else {
uname.Version[i] = ' '
}
}
}
mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine)
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil {
return err
}
return nil
}
/* /*
* Exposed directly * Exposed directly
*/ */
@ -439,6 +502,7 @@ func IoctlGetTermios(fd int, req uint) (*Termios, error) {
//sys Fstatfs(fd int, stat *Statfs_t) (err error) //sys Fstatfs(fd int, stat *Statfs_t) (err error)
//sys Fsync(fd int) (err error) //sys Fsync(fd int) (err error)
//sys Ftruncate(fd int, length int64) (err error) //sys Ftruncate(fd int, length int64) (err error)
//sys Getdents(fd int, buf []byte) (n int, err error)
//sys Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) //sys Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error)
//sys Getdtablesize() (size int) //sys Getdtablesize() (size int)
//sysnb Getegid() (egid int) //sysnb Getegid() (egid int)

View file

@ -808,6 +808,24 @@ func GetsockoptTCPInfo(fd, level, opt int) (*TCPInfo, error) {
return &value, err return &value, err
} }
// GetsockoptString returns the string value of the socket option opt for the
// socket associated with fd at the given socket level.
func GetsockoptString(fd, level, opt int) (string, error) {
buf := make([]byte, 256)
vallen := _Socklen(len(buf))
err := getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen)
if err != nil {
if err == ERANGE {
buf = make([]byte, vallen)
err = getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen)
}
if err != nil {
return "", err
}
}
return string(buf[:vallen-1]), nil
}
func SetsockoptIPMreqn(fd, level, opt int, mreq *IPMreqn) (err error) { func SetsockoptIPMreqn(fd, level, opt int, mreq *IPMreqn) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(mreq), unsafe.Sizeof(*mreq)) return setsockopt(fd, level, opt, unsafe.Pointer(mreq), unsafe.Sizeof(*mreq))
} }

View file

@ -184,17 +184,6 @@ func TestSelect(t *testing.T) {
} }
} }
func TestUname(t *testing.T) {
var utsname unix.Utsname
err := unix.Uname(&utsname)
if err != nil {
t.Fatalf("Uname: %v", err)
}
// conversion from []byte to string, golang.org/issue/20753
t.Logf("OS: %s/%s %s", string(utsname.Sysname[:]), string(utsname.Machine[:]), string(utsname.Release[:]))
}
func TestFstatat(t *testing.T) { func TestFstatat(t *testing.T) {
defer chtmpdir(t)() defer chtmpdir(t)()

View file

@ -118,6 +118,23 @@ func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
return getdents(fd, buf) return getdents(fd, buf)
} }
const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error) = SYS___GETCWD
func Getwd() (string, error) {
var buf [PathMax]byte
_, err := Getcwd(buf[0:])
if err != nil {
return "", err
}
n := clen(buf[:])
if n < 1 {
return "", EINVAL
}
return string(buf[:n]), nil
}
// TODO // TODO
func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) {
return -1, ENOSYS return -1, ENOSYS
@ -167,6 +184,52 @@ func IoctlGetTermios(fd int, req uint) (*Termios, error) {
return &value, err return &value, err
} }
func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname)
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename)
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release)
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version)
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil {
return err
}
// The version might have newlines or tabs in it, convert them to
// spaces.
for i, b := range uname.Version {
if b == '\n' || b == '\t' {
if i == len(uname.Version)-1 {
uname.Version[i] = 0
} else {
uname.Version[i] = ' '
}
}
}
mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine)
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil {
return err
}
return nil
}
/* /*
* Exposed directly * Exposed directly
*/ */

View file

@ -1,11 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build dragonfly freebsd netbsd openbsd
package unix
const ImplementsGetwd = false
func Getwd() (string, error) { return "", ENOTSUP }

View file

@ -71,6 +71,23 @@ func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
return getdents(fd, buf) return getdents(fd, buf)
} }
const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error) = SYS___GETCWD
func Getwd() (string, error) {
var buf [PathMax]byte
_, err := Getcwd(buf[0:])
if err != nil {
return "", err
}
n := clen(buf[:])
if n < 1 {
return "", EINVAL
}
return string(buf[:n]), nil
}
// TODO // TODO
func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) {
return -1, ENOSYS return -1, ENOSYS
@ -135,6 +152,52 @@ func IoctlGetTermios(fd int, req uint) (*Termios, error) {
return &value, err return &value, err
} }
func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname)
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename)
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release)
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil {
return err
}
mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version)
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil {
return err
}
// The version might have newlines or tabs in it, convert them to
// spaces.
for i, b := range uname.Version {
if b == '\n' || b == '\t' {
if i == len(uname.Version)-1 {
uname.Version[i] = 0
} else {
uname.Version[i] = ' '
}
}
}
mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine)
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil {
return err
}
return nil
}
/* /*
* Exposed directly * Exposed directly
*/ */

View file

@ -34,15 +34,6 @@ type SockaddrDatalink struct {
raw RawSockaddrDatalink raw RawSockaddrDatalink
} }
func clen(n []byte) int {
for i := 0; i < len(n); i++ {
if n[i] == 0 {
return i
}
}
return len(n)
}
func direntIno(buf []byte) (uint64, bool) { func direntIno(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino)) return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino))
} }
@ -139,6 +130,18 @@ func Getsockname(fd int) (sa Sockaddr, err error) {
return anyToSockaddr(&rsa) return anyToSockaddr(&rsa)
} }
// GetsockoptString returns the string value of the socket option opt for the
// socket associated with fd at the given socket level.
func GetsockoptString(fd, level, opt int) (string, error) {
buf := make([]byte, 256)
vallen := _Socklen(len(buf))
err := getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen)
if err != nil {
return "", err
}
return string(buf[:vallen-1]), nil
}
const ImplementsGetwd = true const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error) //sys Getcwd(buf []byte) (n int, err error)

View file

@ -48,3 +48,13 @@ func TestItoa(t *testing.T) {
t.Fatalf("itoa(%d) = %s, want %s", i, s, f) t.Fatalf("itoa(%d) = %s, want %s", i, s, f)
} }
} }
func TestUname(t *testing.T) {
var utsname unix.Utsname
err := unix.Uname(&utsname)
if err != nil {
t.Fatalf("Uname: %v", err)
}
t.Logf("OS: %s/%s %s", utsname.Sysname[:], utsname.Machine[:], utsname.Release[:])
}

View file

@ -50,6 +50,17 @@ func errnoErr(e syscall.Errno) error {
return e return e
} }
// clen returns the index of the first NULL byte in n or len(n) if n contains no
// NULL byte or len(n) if n contains no NULL byte
func clen(n []byte) int {
for i := 0; i < len(n); i++ {
if n[i] == 0 {
return i
}
}
return len(n)
}
// Mmap manager, for use by operating system-specific implementations. // Mmap manager, for use by operating system-specific implementations.
type mmapper struct { type mmapper struct {

View file

@ -378,6 +378,54 @@ func TestPoll(t *testing.T) {
} }
} }
func TestGetwd(t *testing.T) {
fd, err := os.Open(".")
if err != nil {
t.Fatalf("Open .: %s", err)
}
defer fd.Close()
// These are chosen carefully not to be symlinks on a Mac
// (unlike, say, /var, /etc)
dirs := []string{"/", "/usr/bin"}
if runtime.GOOS == "darwin" {
switch runtime.GOARCH {
case "arm", "arm64":
d1, err := ioutil.TempDir("", "d1")
if err != nil {
t.Fatalf("TempDir: %v", err)
}
d2, err := ioutil.TempDir("", "d2")
if err != nil {
t.Fatalf("TempDir: %v", err)
}
dirs = []string{d1, d2}
}
}
oldwd := os.Getenv("PWD")
for _, d := range dirs {
err = os.Chdir(d)
if err != nil {
t.Fatalf("Chdir: %v", err)
}
pwd, err := unix.Getwd()
if err != nil {
t.Fatalf("Getwd in %s: %s", d, err)
}
os.Setenv("PWD", oldwd)
err = fd.Chdir()
if err != nil {
// We changed the current directory and cannot go back.
// Don't let the tests continue; they'll scribble
// all over some other directory.
fmt.Fprintf(os.Stderr, "fchdir back to dot failed: %s\n", err)
os.Exit(1)
}
if pwd != d {
t.Fatalf("Getwd returned %q want %q", pwd, d)
}
}
}
// mktmpfifo creates a temporary FIFO and provides a cleanup function. // mktmpfifo creates a temporary FIFO and provides a cleanup function.
func mktmpfifo(t *testing.T) (*os.File, func()) { func mktmpfifo(t *testing.T) (*os.File, func()) {
err := unix.Mkfifo("fifo", 0666) err := unix.Mkfifo("fifo", 0666)

View file

@ -6,6 +6,8 @@
package unix package unix
import "time"
// TimespecToNsec converts a Timespec value into a number of // TimespecToNsec converts a Timespec value into a number of
// nanoseconds since the Unix epoch. // nanoseconds since the Unix epoch.
func TimespecToNsec(ts Timespec) int64 { return int64(ts.Sec)*1e9 + int64(ts.Nsec) } func TimespecToNsec(ts Timespec) int64 { return int64(ts.Sec)*1e9 + int64(ts.Nsec) }
@ -22,6 +24,24 @@ func NsecToTimespec(nsec int64) Timespec {
return setTimespec(sec, nsec) return setTimespec(sec, nsec)
} }
// TimeToTimespec converts t into a Timespec.
// On some 32-bit systems the range of valid Timespec values are smaller
// than that of time.Time values. So if t is out of the valid range of
// Timespec, it returns a zero Timespec and ERANGE.
func TimeToTimespec(t time.Time) (Timespec, error) {
sec := t.Unix()
nsec := int64(t.Nanosecond())
ts := setTimespec(sec, nsec)
// Currently all targets have either int32 or int64 for Timespec.Sec.
// If there were a new target with floating point type for it, we have
// to consider the rounding error.
if int64(ts.Sec) != sec {
return Timespec{}, ERANGE
}
return ts, nil
}
// TimevalToNsec converts a Timeval value into a number of nanoseconds // TimevalToNsec converts a Timeval value into a number of nanoseconds
// since the Unix epoch. // since the Unix epoch.
func TimevalToNsec(tv Timeval) int64 { return int64(tv.Sec)*1e9 + int64(tv.Usec)*1e3 } func TimevalToNsec(tv Timeval) int64 { return int64(tv.Sec)*1e9 + int64(tv.Usec)*1e3 }

54
vendor/golang.org/x/sys/unix/timestruct_test.go generated vendored Normal file
View file

@ -0,0 +1,54 @@
// Copyright 2017 The Go Authors. All right reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package unix_test
import (
"testing"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
func TestTimeToTimespec(t *testing.T) {
timeTests := []struct {
time time.Time
valid bool
}{
{time.Unix(0, 0), true},
{time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), true},
{time.Date(2262, time.December, 31, 23, 0, 0, 0, time.UTC), false},
{time.Unix(0x7FFFFFFF, 0), true},
{time.Unix(0x80000000, 0), false},
{time.Unix(0x7FFFFFFF, 1000000000), false},
{time.Unix(0x7FFFFFFF, 999999999), true},
{time.Unix(-0x80000000, 0), true},
{time.Unix(-0x80000001, 0), false},
{time.Date(2038, time.January, 19, 3, 14, 7, 0, time.UTC), true},
{time.Date(2038, time.January, 19, 3, 14, 8, 0, time.UTC), false},
{time.Date(1901, time.December, 13, 20, 45, 52, 0, time.UTC), true},
{time.Date(1901, time.December, 13, 20, 45, 51, 0, time.UTC), false},
}
// Currently all targets have either int32 or int64 for Timespec.Sec.
// If there were a new target with unsigned or floating point type for
// it, this test must be adjusted.
have64BitTime := (unsafe.Sizeof(unix.Timespec{}.Sec) == 8)
for _, tt := range timeTests {
ts, err := unix.TimeToTimespec(tt.time)
tt.valid = tt.valid || have64BitTime
if tt.valid && err != nil {
t.Errorf("TimeToTimespec(%v): %v", tt.time, err)
}
if err == nil {
tstime := time.Unix(int64(ts.Sec), int64(ts.Nsec))
if !tstime.Equal(tt.time) {
t.Errorf("TimeToTimespec(%v) is the time %v", tt.time, tstime)
}
}
}
}

View file

@ -35,6 +35,7 @@ package unix
#include <sys/time.h> #include <sys/time.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/utsname.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <net/bpf.h> #include <net/bpf.h>
#include <net/if.h> #include <net/if.h>
@ -126,6 +127,12 @@ type Dirent C.struct_dirent
type Fsid C.struct_fsid type Fsid C.struct_fsid
// File system limits
const (
PathMax = C.PATH_MAX
)
// Sockets // Sockets
type RawSockaddrInet4 C.struct_sockaddr_in type RawSockaddrInet4 C.struct_sockaddr_in
@ -267,3 +274,7 @@ const (
POLLWRBAND = C.POLLWRBAND POLLWRBAND = C.POLLWRBAND
POLLWRNORM = C.POLLWRNORM POLLWRNORM = C.POLLWRNORM
) )
// Uname
type Utsname C.struct_utsname

View file

@ -36,6 +36,7 @@ package unix
#include <sys/time.h> #include <sys/time.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/utsname.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <net/bpf.h> #include <net/bpf.h>
#include <net/if.h> #include <net/if.h>
@ -215,6 +216,12 @@ type Dirent C.struct_dirent
type Fsid C.struct_fsid type Fsid C.struct_fsid
// File system limits
const (
PathMax = C.PATH_MAX
)
// Advice to Fadvise // Advice to Fadvise
const ( const (
@ -389,3 +396,7 @@ const (
// Capabilities // Capabilities
type CapRights C.struct_cap_rights type CapRights C.struct_cap_rights
// Uname
type Utsname C.struct_utsname

View file

@ -37,6 +37,7 @@ package unix
#include <sys/time.h> #include <sys/time.h>
#include <sys/uio.h> #include <sys/uio.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/utsname.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <net/bpf.h> #include <net/bpf.h>
#include <net/if.h> #include <net/if.h>
@ -111,6 +112,12 @@ type Dirent C.struct_dirent
type Fsid C.fsid_t type Fsid C.fsid_t
// File system limits
const (
PathMax = C.PATH_MAX
)
// Sockets // Sockets
type RawSockaddrInet4 C.struct_sockaddr_in type RawSockaddrInet4 C.struct_sockaddr_in
@ -257,3 +264,7 @@ const (
// Sysctl // Sysctl
type Sysctlnode C.struct_sysctlnode type Sysctlnode C.struct_sysctlnode
// Uname
type Utsname C.struct_utsname

View file

@ -36,6 +36,7 @@ package unix
#include <sys/time.h> #include <sys/time.h>
#include <sys/uio.h> #include <sys/uio.h>
#include <sys/un.h> #include <sys/un.h>
#include <sys/utsname.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <net/bpf.h> #include <net/bpf.h>
#include <net/if.h> #include <net/if.h>
@ -127,6 +128,12 @@ type Dirent C.struct_dirent
type Fsid C.fsid_t type Fsid C.fsid_t
// File system limits
const (
PathMax = C.PATH_MAX
)
// Sockets // Sockets
type RawSockaddrInet4 C.struct_sockaddr_in type RawSockaddrInet4 C.struct_sockaddr_in
@ -269,3 +276,7 @@ const (
POLLWRBAND = C.POLLWRBAND POLLWRBAND = C.POLLWRBAND
POLLWRNORM = C.POLLWRNORM POLLWRNORM = C.POLLWRNORM
) )
// Uname
type Utsname C.struct_utsname

View file

@ -168,6 +168,8 @@ const (
CSTOP = 0x13 CSTOP = 0x13
CSTOPB = 0x400 CSTOPB = 0x400
CSUSP = 0x1a CSUSP = 0x1a
CTL_HW = 0x6
CTL_KERN = 0x1
CTL_MAXNAME = 0xc CTL_MAXNAME = 0xc
CTL_NET = 0x4 CTL_NET = 0x4
DLT_A429 = 0xb8 DLT_A429 = 0xb8
@ -353,6 +355,7 @@ const (
F_UNLCK = 0x2 F_UNLCK = 0x2
F_WRLCK = 0x3 F_WRLCK = 0x3
HUPCL = 0x4000 HUPCL = 0x4000
HW_MACHINE = 0x1
ICANON = 0x100 ICANON = 0x100
ICMP6_FILTER = 0x12 ICMP6_FILTER = 0x12
ICRNL = 0x100 ICRNL = 0x100
@ -835,6 +838,10 @@ const (
IXANY = 0x800 IXANY = 0x800
IXOFF = 0x400 IXOFF = 0x400
IXON = 0x200 IXON = 0x200
KERN_HOSTNAME = 0xa
KERN_OSRELEASE = 0x2
KERN_OSTYPE = 0x1
KERN_VERSION = 0x4
LOCK_EX = 0x2 LOCK_EX = 0x2
LOCK_NB = 0x4 LOCK_NB = 0x4
LOCK_SH = 0x1 LOCK_SH = 0x1

View file

@ -351,6 +351,8 @@ const (
CSTOP = 0x13 CSTOP = 0x13
CSTOPB = 0x400 CSTOPB = 0x400
CSUSP = 0x1a CSUSP = 0x1a
CTL_HW = 0x6
CTL_KERN = 0x1
CTL_MAXNAME = 0x18 CTL_MAXNAME = 0x18
CTL_NET = 0x4 CTL_NET = 0x4
DLT_A429 = 0xb8 DLT_A429 = 0xb8
@ -608,6 +610,7 @@ const (
F_UNLCKSYS = 0x4 F_UNLCKSYS = 0x4
F_WRLCK = 0x3 F_WRLCK = 0x3
HUPCL = 0x4000 HUPCL = 0x4000
HW_MACHINE = 0x1
ICANON = 0x100 ICANON = 0x100
ICMP6_FILTER = 0x12 ICMP6_FILTER = 0x12
ICRNL = 0x100 ICRNL = 0x100
@ -944,6 +947,10 @@ const (
IXANY = 0x800 IXANY = 0x800
IXOFF = 0x400 IXOFF = 0x400
IXON = 0x200 IXON = 0x200
KERN_HOSTNAME = 0xa
KERN_OSRELEASE = 0x2
KERN_OSTYPE = 0x1
KERN_VERSION = 0x4
LOCK_EX = 0x2 LOCK_EX = 0x2
LOCK_NB = 0x4 LOCK_NB = 0x4
LOCK_SH = 0x1 LOCK_SH = 0x1

View file

@ -351,6 +351,8 @@ const (
CSTOP = 0x13 CSTOP = 0x13
CSTOPB = 0x400 CSTOPB = 0x400
CSUSP = 0x1a CSUSP = 0x1a
CTL_HW = 0x6
CTL_KERN = 0x1
CTL_MAXNAME = 0x18 CTL_MAXNAME = 0x18
CTL_NET = 0x4 CTL_NET = 0x4
DLT_A429 = 0xb8 DLT_A429 = 0xb8
@ -608,6 +610,7 @@ const (
F_UNLCKSYS = 0x4 F_UNLCKSYS = 0x4
F_WRLCK = 0x3 F_WRLCK = 0x3
HUPCL = 0x4000 HUPCL = 0x4000
HW_MACHINE = 0x1
ICANON = 0x100 ICANON = 0x100
ICMP6_FILTER = 0x12 ICMP6_FILTER = 0x12
ICRNL = 0x100 ICRNL = 0x100
@ -944,6 +947,10 @@ const (
IXANY = 0x800 IXANY = 0x800
IXOFF = 0x400 IXOFF = 0x400
IXON = 0x200 IXON = 0x200
KERN_HOSTNAME = 0xa
KERN_OSRELEASE = 0x2
KERN_OSTYPE = 0x1
KERN_VERSION = 0x4
LOCK_EX = 0x2 LOCK_EX = 0x2
LOCK_NB = 0x4 LOCK_NB = 0x4
LOCK_SH = 0x1 LOCK_SH = 0x1

View file

@ -351,6 +351,8 @@ const (
CSTOP = 0x13 CSTOP = 0x13
CSTOPB = 0x400 CSTOPB = 0x400
CSUSP = 0x1a CSUSP = 0x1a
CTL_HW = 0x6
CTL_KERN = 0x1
CTL_MAXNAME = 0x18 CTL_MAXNAME = 0x18
CTL_NET = 0x4 CTL_NET = 0x4
DLT_A429 = 0xb8 DLT_A429 = 0xb8
@ -615,6 +617,7 @@ const (
F_UNLCKSYS = 0x4 F_UNLCKSYS = 0x4
F_WRLCK = 0x3 F_WRLCK = 0x3
HUPCL = 0x4000 HUPCL = 0x4000
HW_MACHINE = 0x1
ICANON = 0x100 ICANON = 0x100
ICMP6_FILTER = 0x12 ICMP6_FILTER = 0x12
ICRNL = 0x100 ICRNL = 0x100
@ -951,6 +954,10 @@ const (
IXANY = 0x800 IXANY = 0x800
IXOFF = 0x400 IXOFF = 0x400
IXON = 0x200 IXON = 0x200
KERN_HOSTNAME = 0xa
KERN_OSRELEASE = 0x2
KERN_OSTYPE = 0x1
KERN_VERSION = 0x4
LOCK_EX = 0x2 LOCK_EX = 0x2
LOCK_NB = 0x4 LOCK_NB = 0x4
LOCK_SH = 0x1 LOCK_SH = 0x1

View file

@ -1,5 +1,5 @@
// mkerrors.sh -m32 // mkerrors.sh -m32
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT // Code generated by the command above; see README.md. DO NOT EDIT.
// +build 386,netbsd // +build 386,netbsd
@ -169,6 +169,8 @@ const (
CSTOP = 0x13 CSTOP = 0x13
CSTOPB = 0x400 CSTOPB = 0x400
CSUSP = 0x1a CSUSP = 0x1a
CTL_HW = 0x6
CTL_KERN = 0x1
CTL_MAXNAME = 0xc CTL_MAXNAME = 0xc
CTL_NET = 0x4 CTL_NET = 0x4
CTL_QUERY = -0x2 CTL_QUERY = -0x2
@ -581,6 +583,7 @@ const (
F_UNLCK = 0x2 F_UNLCK = 0x2
F_WRLCK = 0x3 F_WRLCK = 0x3
HUPCL = 0x4000 HUPCL = 0x4000
HW_MACHINE = 0x1
ICANON = 0x100 ICANON = 0x100
ICMP6_FILTER = 0x12 ICMP6_FILTER = 0x12
ICRNL = 0x100 ICRNL = 0x100
@ -970,6 +973,10 @@ const (
IXANY = 0x800 IXANY = 0x800
IXOFF = 0x400 IXOFF = 0x400
IXON = 0x200 IXON = 0x200
KERN_HOSTNAME = 0xa
KERN_OSRELEASE = 0x2
KERN_OSTYPE = 0x1
KERN_VERSION = 0x4
LOCK_EX = 0x2 LOCK_EX = 0x2
LOCK_NB = 0x4 LOCK_NB = 0x4
LOCK_SH = 0x1 LOCK_SH = 0x1

Some files were not shown because too many files have changed in this diff Show more