Commit 9ff56c95 by Ian Lance Taylor

Update to current version of Go library.

From-SVN: r173931
parent 37cb25ed
@@ -1432,7 +1432,7 @@ Type::methods_constructor(Gogo* gogo, Type* methods_type,
       p != smethods.end();
       ++p)
    vals->push_back(this->method_constructor(gogo, method_type, p->first,
-					     p->second));
+					     p->second, only_value_methods));
  return Expression::make_slice_composite_literal(methods_type, vals, bloc);
}
@@ -1444,7 +1444,8 @@ Type::methods_constructor(Gogo* gogo, Type* methods_type,
Expression*
Type::method_constructor(Gogo*, Type* method_type,
			 const std::string& method_name,
-			 const Method* m) const
+			 const Method* m,
+			 bool only_value_methods) const
{
  source_location bloc = BUILTINS_LOCATION;
@@ -1487,6 +1488,25 @@ Type::method_constructor(Gogo*, Type* method_type,
  ++p;
  go_assert(p->field_name() == "typ");
+  if (!only_value_methods && m->is_value_method())
+    {
+      // This is a value method on a pointer type.  Change the type of
+      // the method to use a pointer receiver.  The implementation
+      // always uses a pointer receiver anyhow.
+      Type* rtype = mtype->receiver()->type();
+      Type* prtype = Type::make_pointer_type(rtype);
+      Typed_identifier* receiver =
+	new Typed_identifier(mtype->receiver()->name(), prtype,
+			     mtype->receiver()->location());
+      mtype = Type::make_function_type(receiver,
+				       (mtype->parameters() == NULL
+					? NULL
+					: mtype->parameters()->copy()),
+				       (mtype->results() == NULL
+					? NULL
+					: mtype->results()->copy()),
+				       mtype->location());
+    }
  vals->push_back(Expression::make_type_descriptor(mtype, bloc));
  ++p;
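The block added above rewrites a value method's descriptor when it is emitted for a pointer type, so the receiver recorded in the method's "typ" field is a pointer. As a purely illustrative Go sketch (the type T and method Get are invented for this example), this is the behaviour being described: a value method belongs to the pointer type's method set, and its reflected type carries the pointer receiver as the first parameter.

package main

import (
	"fmt"
	"reflect"
)

type T struct{ N int }

// Get has a value receiver, yet it is also in the method set of *T.
func (t T) Get() int { return t.N }

func main() {
	pt := reflect.TypeOf(&T{N: 7})
	m := pt.Method(0)
	// The receiver appears as the first parameter of the method's type.
	fmt.Println(m.Name, m.Type.In(0)) // Get *main.T
}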
@@ -2779,14 +2799,7 @@ Function_type::type_descriptor_params(Type* params_type,
			       + (receiver != NULL ? 1 : 0));
  if (receiver != NULL)
-    {
-      Type* rtype = receiver->type();
-      // The receiver is always passed as a pointer.  FIXME: Is this
-      // right?  Should that fact affect the type descriptor?
-      if (rtype->points_to() == NULL)
-	rtype = Type::make_pointer_type(rtype);
-      vals->push_back(Expression::make_type_descriptor(rtype, bloc));
-    }
+    vals->push_back(Expression::make_type_descriptor(receiver->type(), bloc));
  if (params != NULL)
    {
@@ -4822,9 +4835,10 @@ Array_type::make_array_type_descriptor_type()
  Type* uintptr_type = Type::lookup_integer_type("uintptr");
  Struct_type* sf =
-    Type::make_builtin_struct_type(3,
+    Type::make_builtin_struct_type(4,
				   "", tdt,
				   "elem", ptdt,
+				   "slice", ptdt,
				   "len", uintptr_type);
  ret = Type::make_builtin_named_type("ArrayType", sf);
@@ -4891,6 +4905,11 @@ Array_type::array_type_descriptor(Gogo* gogo, Named_type* name)
  vals->push_back(Expression::make_type_descriptor(this->element_type_, bloc));
  ++p;
+  go_assert(p->field_name() == "slice");
+  Type* slice_type = Type::make_array_type(this->element_type_, NULL);
+  vals->push_back(Expression::make_type_descriptor(slice_type, bloc));
+  ++p;
  go_assert(p->field_name() == "len");
  vals->push_back(Expression::make_cast(p->type(), this->length_, bloc));
@@ -5375,8 +5394,9 @@ Channel_type::do_make_expression_tree(Translate_context* context,
  Gogo* gogo = context->gogo();
  tree channel_type = type_to_tree(this->get_backend(gogo));
-  tree element_tree = type_to_tree(this->element_type_->get_backend(gogo));
-  tree element_size_tree = size_in_bytes(element_tree);
+  Type* ptdt = Type::make_type_descriptor_ptr_type();
+  tree element_type_descriptor =
+    this->element_type_->type_descriptor_pointer(gogo);
  tree bad_index = NULL_TREE;
@@ -5402,8 +5422,8 @@ Channel_type::do_make_expression_tree(Translate_context* context,
				"__go_new_channel",
				2,
				channel_type,
-				sizetype,
-				element_size_tree,
+				type_to_tree(ptdt->get_backend(gogo)),
+				element_type_descriptor,
				sizetype,
				expr_tree);
  if (ret == error_mark_node)
@@ -6242,7 +6262,16 @@ Interface_type::do_reflection(Gogo* gogo, std::string* ret) const
	  if (p != this->methods_->begin())
	    ret->append(";");
	  ret->push_back(' ');
-	  ret->append(Gogo::unpack_hidden_name(p->name()));
+	  if (!Gogo::is_hidden_name(p->name()))
+	    ret->append(p->name());
+	  else
+	    {
+	      // This matches what the gc compiler does.
+	      std::string prefix = Gogo::hidden_name_prefix(p->name());
+	      ret->append(prefix.substr(prefix.find('.') + 1));
+	      ret->push_back('.');
+	      ret->append(Gogo::unpack_hidden_name(p->name()));
+	    }
	  std::string sub = p->type()->reflection(gogo);
	  go_assert(sub.compare(0, 4, "func") == 0);
	  sub = sub.substr(4);
......
@@ -1044,7 +1044,7 @@ class Type
  // Build a composite literal for one method.
  Expression*
  method_constructor(Gogo*, Type* method_type, const std::string& name,
-		     const Method*) const;
+		     const Method*, bool only_value_methods) const;
  static tree
  build_receive_return_type(tree type);
......
@@ -8,7 +8,7 @@ package main
import "reflect"
-func typeof(x interface{}) string { return reflect.Typeof(x).String() }
+func typeof(x interface{}) string { return reflect.TypeOf(x).String() }
func f() int { return 0 }
......
@@ -5,23 +5,26 @@
// license that can be found in the LICENSE file.
package main
import "reflect"
-type S1 struct { i int }
-type S2 struct { S1 }
+type S1 struct{ i int }
+type S2 struct{ S1 }
func main() {
-	typ := reflect.Typeof(S2{}).(*reflect.StructType);
-	f := typ.Field(0);
+	typ := reflect.TypeOf(S2{})
+	f := typ.Field(0)
	if f.Name != "S1" || f.Anonymous != true {
-		println("BUG: ", f.Name, f.Anonymous);
-		return;
+		println("BUG: ", f.Name, f.Anonymous)
+		return
	}
-	f, ok := typ.FieldByName("S1");
+	f, ok := typ.FieldByName("S1")
	if !ok {
-		println("BUG: missing S1");
-		return;
+		println("BUG: missing S1")
+		return
	}
	if !f.Anonymous {
-		println("BUG: S1 is not anonymous");
-		return;
+		println("BUG: S1 is not anonymous")
+		return
	}
}
@@ -38,11 +38,11 @@ func main() {
	// meaning that reflect data for v0, v1 didn't get confused.
	// path is full (rooted) path name. check suffix for gc, prefix for gccgo
-	if s := reflect.Typeof(v0).PkgPath(); !strings.HasSuffix(s, "/bug0") && !strings.HasPrefix(s, "bug0") {
+	if s := reflect.TypeOf(v0).PkgPath(); !strings.HasSuffix(s, "/bug0") && !strings.HasPrefix(s, "bug0") {
		println("bad v0 path", len(s), s)
		panic("fail")
	}
-	if s := reflect.Typeof(v1).PkgPath(); !strings.HasSuffix(s, "/bug1") && !strings.HasPrefix(s, "bug1") {
+	if s := reflect.TypeOf(v1).PkgPath(); !strings.HasSuffix(s, "/bug1") && !strings.HasPrefix(s, "bug1") {
		println("bad v1 path", s)
		panic("fail")
	}
......
@@ -46,34 +46,34 @@ func main() {
	x.t = add("abc", "def")
	x.u = 1
	x.v = 2
-	x.w = 1<<28
-	x.x = 2<<28
+	x.w = 1 << 28
+	x.x = 2 << 28
	x.y = 0x12345678
	x.z = x.y
	// check mem and string
-	v := reflect.NewValue(x)
-	i := v.(*reflect.StructValue).Field(0)
-	j := v.(*reflect.StructValue).Field(1)
+	v := reflect.ValueOf(x)
+	i := v.Field(0)
+	j := v.Field(1)
	assert(i.Interface() == j.Interface())
-	s := v.(*reflect.StructValue).Field(2)
-	t := v.(*reflect.StructValue).Field(3)
+	s := v.Field(2)
+	t := v.Field(3)
	assert(s.Interface() == t.Interface())
	// make sure different values are different.
	// make sure whole word is being compared,
	// not just a single byte.
-	i = v.(*reflect.StructValue).Field(4)
-	j = v.(*reflect.StructValue).Field(5)
+	i = v.Field(4)
+	j = v.Field(5)
	assert(i.Interface() != j.Interface())
-	i = v.(*reflect.StructValue).Field(6)
-	j = v.(*reflect.StructValue).Field(7)
+	i = v.Field(6)
+	j = v.Field(7)
	assert(i.Interface() != j.Interface())
-	i = v.(*reflect.StructValue).Field(8)
-	j = v.(*reflect.StructValue).Field(9)
+	i = v.Field(8)
+	j = v.Field(9)
	assert(i.Interface() == j.Interface())
}
......
@@ -25,9 +25,9 @@ func main() {
	println(c)
	var a interface{}
-	switch c := reflect.NewValue(a).(type) {
-	case *reflect.ComplexValue:
-		v := c.Get()
+	switch c := reflect.ValueOf(a); c.Kind() {
+	case reflect.Complex64, reflect.Complex128:
+		v := c.Complex()
		_, _ = complex128(v), true
	}
}
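The test updates above track the reflect API change in this library version: reflect.Typeof and reflect.NewValue become reflect.TypeOf and reflect.ValueOf, and the per-kind value types (*reflect.StructValue, *reflect.ComplexValue, ...) are replaced by Kind-based switches on a single reflect.Value. A minimal sketch of the new style (the struct pair is invented for this example):

package main

import (
	"fmt"
	"reflect"
)

type pair struct {
	A int
	B complex128
}

func main() {
	v := reflect.ValueOf(pair{A: 3, B: 2 + 4i})
	// Fields are reached directly on the Value; no *reflect.StructValue cast.
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		switch f.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			fmt.Println("int field:", f.Int())
		case reflect.Complex64, reflect.Complex128:
			fmt.Println("complex field:", f.Complex())
		}
	}
}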
-f618e5e0991d
+aea0ba6e5935
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The tar package implements access to tar archives.
+// Package tar implements access to tar archives.
// It aims to cover most of the variations, including those produced
// by GNU and BSD tars.
//
......
@@ -10,6 +10,7 @@ package tar
import (
	"bytes"
	"io"
+	"io/ioutil"
	"os"
	"strconv"
)
@@ -27,13 +28,13 @@ var (
//	tr := tar.NewReader(r)
//	for {
//		hdr, err := tr.Next()
-//		if err != nil {
-//			// handle error
-//		}
-//		if hdr == nil {
+//		if err == os.EOF {
//			// end of tar archive
//			break
//		}
+//		if err != nil {
+//			// handle error
+//		}
//		io.Copy(data, tr)
//	}
type Reader struct {
@@ -84,12 +85,6 @@ func (tr *Reader) octal(b []byte) int64 {
	return int64(x)
}
-type ignoreWriter struct{}
-
-func (ignoreWriter) Write(b []byte) (n int, err os.Error) {
-	return len(b), nil
-}
-
// Skip any unread bytes in the existing file entry, as well as any alignment padding.
func (tr *Reader) skipUnread() {
	nr := tr.nb + tr.pad // number of bytes to skip
@@ -99,7 +94,7 @@ func (tr *Reader) skipUnread() {
			return
		}
	}
-	_, tr.err = io.Copyn(ignoreWriter{}, tr.r, nr)
+	_, tr.err = io.Copyn(ioutil.Discard, tr.r, nr)
}
func (tr *Reader) verifyChecksum(header []byte) bool {
......
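The updated doc comment above changes the iteration idiom for tar.Reader: Next now signals the end of the archive with os.EOF rather than a nil header. A minimal hedged sketch of the new loop against this library version's error API (os.EOF, os.Error); the file name archive.tar is made up:

package main

import (
	"archive/tar"
	"fmt"
	"io"
	"io/ioutil"
	"os"
)

func main() {
	f, err := os.Open("archive.tar") // hypothetical input file
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()
	tr := tar.NewReader(f)
	for {
		hdr, err := tr.Next()
		if err == os.EOF {
			break // end of tar archive
		}
		if err != nil {
			fmt.Println(err)
			return
		}
		fmt.Println(hdr.Name)
		io.Copy(ioutil.Discard, tr) // skip the file body
	}
}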
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
/*
-The zip package provides support for reading ZIP archives.
+Package zip provides support for reading ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
@@ -35,6 +35,11 @@ type Reader struct {
	Comment string
}
+type ReadCloser struct {
+	f *os.File
+	Reader
+}
type File struct {
	FileHeader
	zipr io.ReaderAt
@@ -47,43 +52,60 @@ func (f *File) hasDataDescriptor() bool {
	return f.Flags&0x8 != 0
}
-// OpenReader will open the Zip file specified by name and return a Reader.
-func OpenReader(name string) (*Reader, os.Error) {
+// OpenReader will open the Zip file specified by name and return a ReaderCloser.
+func OpenReader(name string) (*ReadCloser, os.Error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	fi, err := f.Stat()
	if err != nil {
+		f.Close()
+		return nil, err
+	}
+	r := new(ReadCloser)
+	if err := r.init(f, fi.Size); err != nil {
+		f.Close()
		return nil, err
	}
-	return NewReader(f, fi.Size)
+	return r, nil
}
// NewReader returns a new Reader reading from r, which is assumed to
// have the given size in bytes.
func NewReader(r io.ReaderAt, size int64) (*Reader, os.Error) {
-	end, err := readDirectoryEnd(r, size)
-	if err != nil {
+	zr := new(Reader)
+	if err := zr.init(r, size); err != nil {
		return nil, err
	}
-	z := &Reader{
-		r:       r,
-		File:    make([]*File, end.directoryRecords),
-		Comment: end.comment,
+	return zr, nil
+}
+
+func (z *Reader) init(r io.ReaderAt, size int64) os.Error {
+	end, err := readDirectoryEnd(r, size)
+	if err != nil {
+		return err
	}
+	z.r = r
+	z.File = make([]*File, end.directoryRecords)
+	z.Comment = end.comment
	rs := io.NewSectionReader(r, 0, size)
	if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil {
-		return nil, err
+		return err
	}
	buf := bufio.NewReader(rs)
	for i := range z.File {
		z.File[i] = &File{zipr: r, zipsize: size}
		if err := readDirectoryHeader(z.File[i], buf); err != nil {
-			return nil, err
+			return err
		}
	}
-	return z, nil
+	return nil
+}
+
+// Close closes the Zip file, rendering it unusable for I/O.
+func (rc *ReadCloser) Close() os.Error {
+	return rc.f.Close()
}
// Open returns a ReadCloser that provides access to the File's contents.
......
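OpenReader now returns a *ReadCloser that owns the underlying os.File, so callers are expected to Close it when done. A minimal hedged sketch against this version's API; the archive name a.zip is made up:

package main

import (
	"archive/zip"
	"fmt"
)

func main() {
	// OpenReader returns a *zip.ReadCloser; the caller must Close it
	// to release the underlying *os.File.
	rc, err := zip.OpenReader("a.zip") // hypothetical archive
	if err != nil {
		fmt.Println(err)
		return
	}
	defer rc.Close()
	for _, f := range rc.File {
		fmt.Println(f.Name, f.UncompressedSize)
	}
}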
@@ -76,6 +76,12 @@ func readTestZip(t *testing.T, zt ZipTest) {
		return
	}
+	// bail if file is not zip
+	if err == FormatError {
+		return
+	}
+	defer z.Close()
	// bail here if no Files expected to be tested
	// (there may actually be files in the zip, but we don't care)
	if zt.File == nil {
......
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The asn1 package implements parsing of DER-encoded ASN.1 data structures,
+// Package asn1 implements parsing of DER-encoded ASN.1 data structures,
// as defined in ITU-T Rec X.690.
//
// See also ``A Layman's Guide to a Subset of ASN.1, BER, and DER,''
@@ -373,7 +373,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte array and returns them as a
// slice of Go values of the given type.
-func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflect.Type) (ret *reflect.SliceValue, err os.Error) {
+func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
	expectedTag, compoundType, ok := getUniversalType(elemType)
	if !ok {
		err = StructuralError{"unknown Go type for slice"}
@@ -409,7 +409,7 @@ func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflec
	params := fieldParameters{}
	offset := 0
	for i := 0; i < numElements; i++ {
-		offset, err = parseField(ret.Elem(i), bytes, offset, params)
+		offset, err = parseField(ret.Index(i), bytes, offset, params)
		if err != nil {
			return
		}
@@ -418,13 +418,13 @@ func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflec
}
var (
-	bitStringType        = reflect.Typeof(BitString{})
-	objectIdentifierType = reflect.Typeof(ObjectIdentifier{})
-	enumeratedType       = reflect.Typeof(Enumerated(0))
-	flagType             = reflect.Typeof(Flag(false))
-	timeType             = reflect.Typeof(&time.Time{})
-	rawValueType         = reflect.Typeof(RawValue{})
-	rawContentsType      = reflect.Typeof(RawContent(nil))
+	bitStringType        = reflect.TypeOf(BitString{})
+	objectIdentifierType = reflect.TypeOf(ObjectIdentifier{})
+	enumeratedType       = reflect.TypeOf(Enumerated(0))
+	flagType             = reflect.TypeOf(Flag(false))
+	timeType             = reflect.TypeOf(&time.Time{})
+	rawValueType         = reflect.TypeOf(RawValue{})
+	rawContentsType      = reflect.TypeOf(RawContent(nil))
)
// invalidLength returns true iff offset + length > sliceLength, or if the
@@ -461,13 +461,12 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
		}
		result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]}
		offset += t.length
-		v.(*reflect.StructValue).Set(reflect.NewValue(result).(*reflect.StructValue))
+		v.Set(reflect.ValueOf(result))
		return
	}
	// Deal with the ANY type.
-	if ifaceType, ok := fieldType.(*reflect.InterfaceType); ok && ifaceType.NumMethod() == 0 {
-		ifaceValue := v.(*reflect.InterfaceValue)
+	if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 {
		var t tagAndLength
		t, offset, err = parseTagAndLength(bytes, offset)
		if err != nil {
@@ -506,7 +505,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
			return
		}
		if result != nil {
-			ifaceValue.Set(reflect.NewValue(result))
+			v.Set(reflect.ValueOf(result))
		}
		return
	}
@@ -536,9 +535,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
			err = StructuralError{"Zero length explicit tag was not an asn1.Flag"}
			return
		}
-		flagValue := v.(*reflect.BoolValue)
-		flagValue.Set(true)
+		v.SetBool(true)
		return
	}
} else {
...@@ -606,23 +603,20 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -606,23 +603,20 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
switch fieldType { switch fieldType {
case objectIdentifierType: case objectIdentifierType:
newSlice, err1 := parseObjectIdentifier(innerBytes) newSlice, err1 := parseObjectIdentifier(innerBytes)
sliceValue := v.(*reflect.SliceValue) v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice)))
sliceValue.Set(reflect.MakeSlice(sliceValue.Type().(*reflect.SliceType), len(newSlice), len(newSlice)))
if err1 == nil { if err1 == nil {
reflect.Copy(sliceValue, reflect.NewValue(newSlice).(reflect.ArrayOrSliceValue)) reflect.Copy(v, reflect.ValueOf(newSlice))
} }
err = err1 err = err1
return return
case bitStringType: case bitStringType:
structValue := v.(*reflect.StructValue)
bs, err1 := parseBitString(innerBytes) bs, err1 := parseBitString(innerBytes)
if err1 == nil { if err1 == nil {
structValue.Set(reflect.NewValue(bs).(*reflect.StructValue)) v.Set(reflect.ValueOf(bs))
} }
err = err1 err = err1
return return
case timeType: case timeType:
ptrValue := v.(*reflect.PtrValue)
var time *time.Time var time *time.Time
var err1 os.Error var err1 os.Error
if universalTag == tagUTCTime { if universalTag == tagUTCTime {
...@@ -631,55 +625,53 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -631,55 +625,53 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
time, err1 = parseGeneralizedTime(innerBytes) time, err1 = parseGeneralizedTime(innerBytes)
} }
if err1 == nil { if err1 == nil {
ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue)) v.Set(reflect.ValueOf(time))
} }
err = err1 err = err1
return return
case enumeratedType: case enumeratedType:
parsedInt, err1 := parseInt(innerBytes) parsedInt, err1 := parseInt(innerBytes)
enumValue := v.(*reflect.IntValue)
if err1 == nil { if err1 == nil {
enumValue.Set(int64(parsedInt)) v.SetInt(int64(parsedInt))
} }
err = err1 err = err1
return return
case flagType: case flagType:
flagValue := v.(*reflect.BoolValue) v.SetBool(true)
flagValue.Set(true)
return return
} }
switch val := v.(type) { switch val := v; val.Kind() {
case *reflect.BoolValue: case reflect.Bool:
parsedBool, err1 := parseBool(innerBytes) parsedBool, err1 := parseBool(innerBytes)
if err1 == nil { if err1 == nil {
val.Set(parsedBool) val.SetBool(parsedBool)
} }
err = err1 err = err1
return return
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch val.Type().Kind() { switch val.Type().Kind() {
case reflect.Int: case reflect.Int:
parsedInt, err1 := parseInt(innerBytes) parsedInt, err1 := parseInt(innerBytes)
if err1 == nil { if err1 == nil {
val.Set(int64(parsedInt)) val.SetInt(int64(parsedInt))
} }
err = err1 err = err1
return return
case reflect.Int64: case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes) parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil { if err1 == nil {
val.Set(parsedInt) val.SetInt(parsedInt)
} }
err = err1 err = err1
return return
} }
case *reflect.StructValue: case reflect.Struct:
structType := fieldType.(*reflect.StructType) structType := fieldType
if structType.NumField() > 0 && if structType.NumField() > 0 &&
structType.Field(0).Type == rawContentsType { structType.Field(0).Type == rawContentsType {
bytes := bytes[initOffset:offset] bytes := bytes[initOffset:offset]
val.Field(0).SetValue(reflect.NewValue(RawContent(bytes))) val.Field(0).Set(reflect.ValueOf(RawContent(bytes)))
} }
innerOffset := 0 innerOffset := 0
...@@ -697,11 +689,11 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -697,11 +689,11 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
// adding elements to the end has been used in X.509 as the // adding elements to the end has been used in X.509 as the
// version numbers have increased. // version numbers have increased.
return return
case *reflect.SliceValue: case reflect.Slice:
sliceType := fieldType.(*reflect.SliceType) sliceType := fieldType
if sliceType.Elem().Kind() == reflect.Uint8 { if sliceType.Elem().Kind() == reflect.Uint8 {
val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes))) val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
reflect.Copy(val, reflect.NewValue(innerBytes).(reflect.ArrayOrSliceValue)) reflect.Copy(val, reflect.ValueOf(innerBytes))
return return
} }
newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem()) newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
...@@ -710,7 +702,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -710,7 +702,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
err = err1 err = err1
return return
case *reflect.StringValue: case reflect.String:
var v string var v string
switch universalTag { switch universalTag {
case tagPrintableString: case tagPrintableString:
...@@ -729,7 +721,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -729,7 +721,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)} err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
} }
if err == nil { if err == nil {
val.Set(v) val.SetString(v)
} }
return return
} }
...@@ -748,9 +740,9 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { ...@@ -748,9 +740,9 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
if params.defaultValue == nil { if params.defaultValue == nil {
return return
} }
switch val := v.(type) { switch val := v; val.Kind() {
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val.Set(*params.defaultValue) val.SetInt(*params.defaultValue)
} }
return return
} }
...@@ -806,7 +798,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) { ...@@ -806,7 +798,7 @@ func Unmarshal(b []byte, val interface{}) (rest []byte, err os.Error) {
// UnmarshalWithParams allows field parameters to be specified for the // UnmarshalWithParams allows field parameters to be specified for the
// top-level element. The form of the params is the same as the field tags. // top-level element. The form of the params is the same as the field tags.
func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) { func UnmarshalWithParams(b []byte, val interface{}, params string) (rest []byte, err os.Error) {
v := reflect.NewValue(val).(*reflect.PtrValue).Elem() v := reflect.ValueOf(val).Elem()
offset, err := parseField(v, b, 0, parseFieldParameters(params)) offset, err := parseField(v, b, 0, parseFieldParameters(params))
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) { ...@@ -267,11 +267,6 @@ func TestParseFieldParameters(t *testing.T) {
} }
} }
type unmarshalTest struct {
in []byte
out interface{}
}
type TestObjectIdentifierStruct struct { type TestObjectIdentifierStruct struct {
OID ObjectIdentifier OID ObjectIdentifier
} }
...@@ -290,7 +285,10 @@ type TestElementsAfterString struct { ...@@ -290,7 +285,10 @@ type TestElementsAfterString struct {
A, B int A, B int
} }
var unmarshalTestData []unmarshalTest = []unmarshalTest{ var unmarshalTestData = []struct {
in []byte
out interface{}
}{
{[]byte{0x02, 0x01, 0x42}, newInt(0x42)}, {[]byte{0x02, 0x01, 0x42}, newInt(0x42)},
{[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}}, {[]byte{0x30, 0x08, 0x06, 0x06, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d}, &TestObjectIdentifierStruct{[]int{1, 2, 840, 113549}}},
{[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}}, {[]byte{0x03, 0x04, 0x06, 0x6e, 0x5d, 0xc0}, &BitString{[]byte{110, 93, 192}, 18}},
...@@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{ ...@@ -309,9 +307,7 @@ var unmarshalTestData []unmarshalTest = []unmarshalTest{
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
for i, test := range unmarshalTestData { for i, test := range unmarshalTestData {
pv := reflect.MakeZero(reflect.NewValue(test.out).Type()) pv := reflect.New(reflect.TypeOf(test.out).Elem())
zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem())
pv.(*reflect.PtrValue).PointTo(zv)
val := pv.Interface() val := pv.Interface()
_, err := Unmarshal(test.in, val) _, err := Unmarshal(test.in, val)
if err != nil { if err != nil {
......
...@@ -133,14 +133,14 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { ...@@ -133,14 +133,14 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
case enumeratedType: case enumeratedType:
return tagEnum, false, true return tagEnum, false, true
} }
switch t := t.(type) { switch t.Kind() {
case *reflect.BoolType: case reflect.Bool:
return tagBoolean, false, true return tagBoolean, false, true
case *reflect.IntType: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return tagInteger, false, true return tagInteger, false, true
case *reflect.StructType: case reflect.Struct:
return tagSequence, true, true return tagSequence, true, true
case *reflect.SliceType: case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 { if t.Elem().Kind() == reflect.Uint8 {
return tagOctetString, false, true return tagOctetString, false, true
} }
...@@ -148,7 +148,7 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { ...@@ -148,7 +148,7 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
return tagSet, true, true return tagSet, true, true
} }
return tagSequence, true, true return tagSequence, true, true
case *reflect.StringType: case reflect.String:
return tagPrintableString, false, true return tagPrintableString, false, true
} }
return 0, false, false return 0, false, false
......
...@@ -125,6 +125,28 @@ func int64Length(i int64) (numBytes int) { ...@@ -125,6 +125,28 @@ func int64Length(i int64) (numBytes int) {
return return
} }
func marshalLength(out *forkableWriter, i int) (err os.Error) {
n := lengthLength(i)
for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8)))
if err != nil {
return
}
}
return nil
}
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}
func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) { func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) {
b := uint8(t.class) << 6 b := uint8(t.class) << 6
if t.isCompound { if t.isCompound {
...@@ -149,12 +171,12 @@ func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) { ...@@ -149,12 +171,12 @@ func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) {
} }
if t.length >= 128 { if t.length >= 128 {
l := int64Length(int64(t.length)) l := lengthLength(t.length)
err = out.WriteByte(0x80 | byte(l)) err = out.WriteByte(0x80 | byte(l))
if err != nil { if err != nil {
return return
} }
err = marshalInt64(out, int64(t.length)) err = marshalLength(out, t.length)
if err != nil { if err != nil {
return return
} }
...@@ -314,28 +336,28 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -314,28 +336,28 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
} }
switch v := value.(type) { switch v := value; v.Kind() {
case *reflect.BoolValue: case reflect.Bool:
if v.Get() { if v.Bool() {
return out.WriteByte(255) return out.WriteByte(255)
} else { } else {
return out.WriteByte(0) return out.WriteByte(0)
} }
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return marshalInt64(out, int64(v.Get())) return marshalInt64(out, int64(v.Int()))
case *reflect.StructValue: case reflect.Struct:
t := v.Type().(*reflect.StructType) t := v.Type()
startingField := 0 startingField := 0
// If the first element of the structure is a non-empty // If the first element of the structure is a non-empty
// RawContents, then we don't bother serialising the rest. // RawContents, then we don't bother serialising the rest.
if t.NumField() > 0 && t.Field(0).Type == rawContentsType { if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
s := v.Field(0).(*reflect.SliceValue) s := v.Field(0)
if s.Len() > 0 { if s.Len() > 0 {
bytes := make([]byte, s.Len()) bytes := make([]byte, s.Len())
for i := 0; i < s.Len(); i++ { for i := 0; i < s.Len(); i++ {
bytes[i] = uint8(s.Elem(i).(*reflect.UintValue).Get()) bytes[i] = uint8(s.Index(i).Uint())
} }
/* The RawContents will contain the tag and /* The RawContents will contain the tag and
* length fields but we'll also be writing * length fields but we'll also be writing
...@@ -357,12 +379,12 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -357,12 +379,12 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
} }
} }
return return
case *reflect.SliceValue: case reflect.Slice:
sliceType := v.Type().(*reflect.SliceType) sliceType := v.Type()
if sliceType.Elem().Kind() == reflect.Uint8 { if sliceType.Elem().Kind() == reflect.Uint8 {
bytes := make([]byte, v.Len()) bytes := make([]byte, v.Len())
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
bytes[i] = uint8(v.Elem(i).(*reflect.UintValue).Get()) bytes[i] = uint8(v.Index(i).Uint())
} }
_, err = out.Write(bytes) _, err = out.Write(bytes)
return return
...@@ -372,17 +394,17 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -372,17 +394,17 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
var pre *forkableWriter var pre *forkableWriter
pre, out = out.fork() pre, out = out.fork()
err = marshalField(pre, v.Elem(i), params) err = marshalField(pre, v.Index(i), params)
if err != nil { if err != nil {
return return
} }
} }
return return
case *reflect.StringValue: case reflect.String:
if params.stringType == tagIA5String { if params.stringType == tagIA5String {
return marshalIA5String(out, v.Get()) return marshalIA5String(out, v.String())
} else { } else {
return marshalPrintableString(out, v.Get()) return marshalPrintableString(out, v.String())
} }
return return
} }
...@@ -392,7 +414,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -392,7 +414,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err os.Error) { func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err os.Error) {
// If the field is an interface{} then recurse into it. // If the field is an interface{} then recurse into it.
if v, ok := v.(*reflect.InterfaceValue); ok && v.Type().(*reflect.InterfaceType).NumMethod() == 0 { if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
return marshalField(out, v.Elem(), params) return marshalField(out, v.Elem(), params)
} }
...@@ -406,7 +428,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -406,7 +428,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return return
} }
if params.optional && reflect.DeepEqual(v.Interface(), reflect.MakeZero(v.Type()).Interface()) { if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return return
} }
...@@ -471,7 +493,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -471,7 +493,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
// Marshal returns the ASN.1 encoding of val. // Marshal returns the ASN.1 encoding of val.
func Marshal(val interface{}) ([]byte, os.Error) { func Marshal(val interface{}) ([]byte, os.Error) {
var out bytes.Buffer var out bytes.Buffer
v := reflect.NewValue(val) v := reflect.ValueOf(val)
f := newForkableWriter() f := newForkableWriter()
err := marshalField(f, v, fieldParameters{}) err := marshalField(f, v, fieldParameters{})
if err != nil { if err != nil {
......
@@ -77,6 +77,30 @@ var marshalTests = []marshalTest{
	{ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"},
	{ObjectIdentifier([]int{1, 2, 840, 133549, 1, 1, 5}), "06092a864888932d010105"},
	{"test", "130474657374"},
{
"" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 127 times 'x'
"137f" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"78787878787878787878787878787878787878787878787878787878787878",
},
{
"" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" +
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // This is 128 times 'x'
"138180" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"7878787878787878787878787878787878787878787878787878787878787878" +
"7878787878787878787878787878787878787878787878787878787878787878",
},
	{ia5StringTest{"test"}, "3006160474657374"},
	{printableStringTest{"test"}, "3006130474657374"},
	{printableStringTest{"test*"}, "30071305746573742a"},
......
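The new test vectors above exercise the long-form length encoding added by marshalLength and lengthLength: a 127-byte string still fits the short form (13 7f ...), while 128 bytes needs a long-form length (13 81 80 ...). A standalone sketch of the same rule, not the package's code, just the arithmetic:

package main

import "fmt"

// lengthOctets returns the DER encoding of a content length:
// one octet for lengths below 128, otherwise 0x80|n followed by
// n big-endian length octets.
func lengthOctets(length int) []byte {
	if length < 128 {
		return []byte{byte(length)}
	}
	n := 1
	for v := length; v > 255; v >>= 8 {
		n++
	}
	out := []byte{0x80 | byte(n)}
	for i := n - 1; i >= 0; i-- {
		out = append(out, byte(length>>uint(i*8)))
	}
	return out
}

func main() {
	fmt.Printf("% x\n", lengthOctets(127)) // 7f
	fmt.Printf("% x\n", lengthOctets(128)) // 81 80
}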
@@ -337,6 +337,10 @@ func fmtbase(ch int) int {
// 'x' (hexadecimal).
//
func (x *Int) Format(s fmt.State, ch int) {
+	if x == nil {
+		fmt.Fprint(s, "<nil>")
+		return
+	}
	if x.neg {
		fmt.Fprint(s, "-")
	}
......
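The nil check added to (*Int).Format means formatting a nil *Int prints "<nil>" instead of dereferencing a nil receiver. A small hedged sketch of the behaviour this guards, using this library version's import path ("big" rather than the later "math/big"):

package main

import (
	"big"
	"fmt"
)

func main() {
	var x *big.Int // nil pointer
	// fmt calls the Format method; the nil check makes this print "<nil>".
	fmt.Printf("%d\n", x)
}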
@@ -2,11 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This file contains operations on unsigned multi-precision integers.
-// These are the building blocks for the operations on signed integers
-// and rationals.
-
-// This package implements multi-precision arithmetic (big numbers).
+// Package big implements multi-precision arithmetic (big numbers).
// The following numeric types are supported:
//
//	- Int	signed integers
@@ -18,6 +14,10 @@
//
package big
+// This file contains operations on unsigned multi-precision integers.
+// These are the building blocks for the operations on signed integers
+// and rationals.
+
import "rand"
// An unsigned integer x of the form
......
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This package implements buffered I/O. It wraps an io.Reader or io.Writer
+// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer
// object, creating another object (Reader or Writer) that also implements
// the interface but provides buffering and some help for textual I/O.
package bufio
@@ -282,6 +282,33 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) {
	panic("not reached")
}
// ReadLine tries to return a single line, not including the end-of-line bytes.
// If the line was too long for the buffer then isPrefix is set and the
// beginning of the line is returned. The rest of the line will be returned
// from future calls. isPrefix will be false when returning the last fragment
// of the line. The returned buffer is only valid until the next call to
// ReadLine. ReadLine either returns a non-nil line or it returns an error,
// never both.
func (b *Reader) ReadLine() (line []byte, isPrefix bool, err os.Error) {
line, err = b.ReadSlice('\n')
if err == ErrBufferFull {
return line, true, nil
}
if len(line) == 0 {
return
}
err = nil
if line[len(line)-1] == '\n' {
line = line[:len(line)-1]
}
if len(line) > 0 && line[len(line)-1] == '\r' {
line = line[:len(line)-1]
}
return
}
// ReadBytes reads until the first occurrence of delim in the input,
// returning a slice containing the data up to and including the delimiter.
// If ReadBytes encounters an error before finding a delimiter,
......
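ReadLine is a new low-level line reader; the doc comment above spells out the isPrefix contract. A minimal hedged usage sketch against this version's API (os.EOF, os.Error):

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"os"
)

func main() {
	r := bufio.NewReader(bytes.NewBufferString("first\nsecond\r\nthird"))
	for {
		line, isPrefix, err := r.ReadLine()
		if err == os.EOF {
			break
		}
		if err != nil {
			fmt.Println(err)
			return
		}
		// isPrefix is true only when a line did not fit in the buffer.
		fmt.Println(string(line), isPrefix)
	}
}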
@@ -9,6 +9,7 @@ import (
	"bytes"
	"fmt"
	"io"
+	"io/ioutil"
	"os"
	"strings"
	"testing"
@@ -570,3 +571,128 @@ func TestPeekThenUnreadRune(t *testing.T) {
	r.UnreadRune()
	r.ReadRune() // Used to panic here
}
var testOutput = []byte("0123456789abcdefghijklmnopqrstuvwxy")
var testInput = []byte("012\n345\n678\n9ab\ncde\nfgh\nijk\nlmn\nopq\nrst\nuvw\nxy")
var testInputrn = []byte("012\r\n345\r\n678\r\n9ab\r\ncde\r\nfgh\r\nijk\r\nlmn\r\nopq\r\nrst\r\nuvw\r\nxy\r\n\n\r\n")
// TestReader wraps a []byte and returns reads of a specific length.
type testReader struct {
data []byte
stride int
}
func (t *testReader) Read(buf []byte) (n int, err os.Error) {
n = t.stride
if n > len(t.data) {
n = len(t.data)
}
if n > len(buf) {
n = len(buf)
}
copy(buf, t.data)
t.data = t.data[n:]
if len(t.data) == 0 {
err = os.EOF
}
return
}
func testReadLine(t *testing.T, input []byte) {
//for stride := 1; stride < len(input); stride++ {
for stride := 1; stride < 2; stride++ {
done := 0
reader := testReader{input, stride}
l, _ := NewReaderSize(&reader, len(input)+1)
for {
line, isPrefix, err := l.ReadLine()
if len(line) > 0 && err != nil {
t.Errorf("ReadLine returned both data and error: %s", err)
}
if isPrefix {
t.Errorf("ReadLine returned prefix")
}
if err != nil {
if err != os.EOF {
t.Fatalf("Got unknown error: %s", err)
}
break
}
if want := testOutput[done : done+len(line)]; !bytes.Equal(want, line) {
t.Errorf("Bad line at stride %d: want: %x got: %x", stride, want, line)
}
done += len(line)
}
if done != len(testOutput) {
t.Errorf("ReadLine didn't return everything: got: %d, want: %d (stride: %d)", done, len(testOutput), stride)
}
}
}
func TestReadLine(t *testing.T) {
testReadLine(t, testInput)
testReadLine(t, testInputrn)
}
func TestLineTooLong(t *testing.T) {
buf := bytes.NewBuffer([]byte("aaabbbcc\n"))
l, _ := NewReaderSize(buf, 3)
line, isPrefix, err := l.ReadLine()
if !isPrefix || !bytes.Equal(line, []byte("aaa")) || err != nil {
t.Errorf("bad result for first line: %x %s", line, err)
}
line, isPrefix, err = l.ReadLine()
if !isPrefix || !bytes.Equal(line, []byte("bbb")) || err != nil {
t.Errorf("bad result for second line: %x", line)
}
line, isPrefix, err = l.ReadLine()
if isPrefix || !bytes.Equal(line, []byte("cc")) || err != nil {
t.Errorf("bad result for third line: %x", line)
}
line, isPrefix, err = l.ReadLine()
if isPrefix || err == nil {
t.Errorf("expected no more lines: %x %s", line, err)
}
}
func TestReadAfterLines(t *testing.T) {
line1 := "line1"
restData := "line2\nline 3\n"
inbuf := bytes.NewBuffer([]byte(line1 + "\n" + restData))
outbuf := new(bytes.Buffer)
maxLineLength := len(line1) + len(restData)/2
l, _ := NewReaderSize(inbuf, maxLineLength)
line, isPrefix, err := l.ReadLine()
if isPrefix || err != nil || string(line) != line1 {
t.Errorf("bad result for first line: isPrefix=%v err=%v line=%q", isPrefix, err, string(line))
}
n, err := io.Copy(outbuf, l)
if int(n) != len(restData) || err != nil {
t.Errorf("bad result for Read: n=%d err=%v", n, err)
}
if outbuf.String() != restData {
t.Errorf("bad result for Read: got %q; expected %q", outbuf.String(), restData)
}
}
func TestReadEmptyBuffer(t *testing.T) {
l, _ := NewReaderSize(bytes.NewBuffer(nil), 10)
line, isPrefix, err := l.ReadLine()
if err != os.EOF {
t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err)
}
}
func TestLinesAfterRead(t *testing.T) {
l, _ := NewReaderSize(bytes.NewBuffer([]byte("foo")), 10)
_, err := ioutil.ReadAll(l)
if err != nil {
t.Error(err)
return
}
line, isPrefix, err := l.ReadLine()
if err != os.EOF {
t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err)
}
}
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The bytes package implements functions for the manipulation of byte slices.
-// Analogous to the facilities of the strings package.
+// Package bytes implements functions for the manipulation of byte slices.
+// It is analogous to the facilities of the strings package.
package bytes
import (
......
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The cmath package provides basic constants
-// and mathematical functions for complex numbers.
+// Package cmath provides basic constants and mathematical functions for
+// complex numbers.
package cmath
import "math"
......
@@ -143,10 +143,18 @@ func (d *compressor) fillWindow(index int) (int, os.Error) {
		d.blockStart = math.MaxInt32
	}
	for i, h := range d.hashHead {
-		d.hashHead[i] = max(h-wSize, -1)
+		v := h - wSize
+		if v < -1 {
+			v = -1
+		}
+		d.hashHead[i] = v
	}
	for i, h := range d.hashPrev {
-		d.hashPrev[i] = max(h-wSize, -1)
+		v := -h - wSize
+		if v < -1 {
+			v = -1
+		}
+		d.hashPrev[i] = v
	}
}
count, err := d.r.Read(d.window[d.windowEnd:])
...@@ -177,10 +185,18 @@ func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error { ...@@ -177,10 +185,18 @@ func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
// Try to find a match starting at index whose length is greater than prevSize. // Try to find a match starting at index whose length is greater than prevSize.
// We only look at chainCount possibilities before giving up. // We only look at chainCount possibilities before giving up.
func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) { func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) {
win := d.window[0 : pos+min(maxMatchLength, lookahead)] minMatchLook := maxMatchLength
if lookahead < minMatchLook {
minMatchLook = lookahead
}
win := d.window[0 : pos+minMatchLook]
// We quit when we get a match that's at least nice long // We quit when we get a match that's at least nice long
nice := min(d.niceMatch, len(win)-pos) nice := len(win) - pos
if d.niceMatch < nice {
nice = d.niceMatch
}
// If we've got a match that's good enough, only look in 1/4 the chain. // If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.maxChainLength tries := d.maxChainLength
...@@ -344,9 +360,12 @@ Loop: ...@@ -344,9 +360,12 @@ Loop:
} }
prevLength := length prevLength := length
prevOffset := offset prevOffset := offset
minIndex := max(index-maxOffset, 0)
length = minMatchLength - 1 length = minMatchLength - 1
offset = 0 offset = 0
minIndex := index - maxOffset
if minIndex < 0 {
minIndex = 0
}
if chainHead >= minIndex && if chainHead >= minIndex &&
(isFastDeflate && lookahead > minMatchLength-1 || (isFastDeflate && lookahead > minMatchLength-1 ||
...@@ -477,6 +496,33 @@ func NewWriter(w io.Writer, level int) *Writer { ...@@ -477,6 +496,33 @@ func NewWriter(w io.Writer, level int) *Writer {
return &Writer{pw, &d} return &Writer{pw, &d}
} }
// NewWriterDict is like NewWriter but initializes the new
// Writer with a preset dictionary. The returned Writer behaves
// as if the dictionary had been written to it without producing
// any compressed output. The compressed data written to w
// can only be decompressed by a Reader initialized with the
// same dictionary.
func NewWriterDict(w io.Writer, level int, dict []byte) *Writer {
dw := &dictWriter{w, false}
zw := NewWriter(dw, level)
zw.Write(dict)
zw.Flush()
dw.enabled = true
return zw
}
type dictWriter struct {
w io.Writer
enabled bool
}
func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
if w.enabled {
return w.w.Write(b)
}
return len(b), nil
}
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
......
@@ -275,3 +275,49 @@ func TestDeflateInflateString(t *testing.T) {
	}
	testToFromWithLevel(t, 1, gold, "2.718281828...")
}
func TestReaderDict(t *testing.T) {
const (
dict = "hello world"
text = "hello again world"
)
var b bytes.Buffer
w := NewWriter(&b, 5)
w.Write([]byte(dict))
w.Flush()
b.Reset()
w.Write([]byte(text))
w.Close()
r := NewReaderDict(&b, []byte(dict))
data, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if string(data) != "hello again world" {
t.Fatalf("read returned %q want %q", string(data), text)
}
}
func TestWriterDict(t *testing.T) {
const (
dict = "hello world"
text = "hello again world"
)
var b bytes.Buffer
w := NewWriter(&b, 5)
w.Write([]byte(dict))
w.Flush()
b.Reset()
w.Write([]byte(text))
w.Close()
var b1 bytes.Buffer
w = NewWriterDict(&b1, 5, []byte(dict))
w.Write([]byte(text))
w.Close()
if !bytes.Equal(b1.Bytes(), b.Bytes()) {
t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes())
}
}
@@ -2,9 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The flate package implements the DEFLATE compressed data
-// format, described in RFC 1951. The gzip and zlib packages
-// implement access to DEFLATE-based file formats.
+// Package flate implements the DEFLATE compressed data format, described in
+// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file
+// formats.
package flate
import (
@@ -526,6 +526,20 @@ func (f *decompressor) dataBlock() os.Error {
	return nil
}
func (f *decompressor) setDict(dict []byte) {
if len(dict) > len(f.hist) {
// Will only remember the tail.
dict = dict[len(dict)-len(f.hist):]
}
f.hp = copy(f.hist[:], dict)
if f.hp == len(f.hist) {
f.hp = 0
f.hfull = true
}
f.hw = f.hp
}
func (f *decompressor) moreBits() os.Error {
	c, err := f.r.ReadByte()
	if err != nil {
@@ -618,3 +632,16 @@ func NewReader(r io.Reader) io.ReadCloser {
	go func() { pw.CloseWithError(f.decompress(r, pw)) }()
	return pr
}
// NewReaderDict is like NewReader but initializes the reader
// with a preset dictionary. The returned Reader behaves as if
// the uncompressed data stream started with the given dictionary,
// which has already been read. NewReaderDict is typically used
// to read data compressed by NewWriterDict.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor
f.setDict(dict)
pr, pw := io.Pipe()
go func() { pw.CloseWithError(f.decompress(r, pw)) }()
return pr
}
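NewWriterDict and NewReaderDict (added above) let both ends share a preset dictionary: the writer behaves as if the dictionary had already been written, and the reader as if it had already been read. A hedged round-trip sketch using this version's signatures (NewWriterDict returns a *Writer directly, errors are os.Error):

package main

import (
	"bytes"
	"compress/flate"
	"fmt"
	"io/ioutil"
)

func main() {
	dict := []byte("hello world")
	var compressed bytes.Buffer
	// Compress with a preset dictionary; no dictionary bytes appear in the
	// output, only back-references into it.
	w := flate.NewWriterDict(&compressed, flate.BestCompression, dict)
	w.Write([]byte("hello again world"))
	w.Close()
	// Decompress with the same dictionary.
	r := flate.NewReaderDict(&compressed, dict)
	data, _ := ioutil.ReadAll(r)
	r.Close()
	fmt.Println(string(data)) // hello again world
}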
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The gzip package implements reading and writing of
-// gzip format compressed files, as specified in RFC 1952.
+// Package gzip implements reading and writing of gzip format compressed files,
+// as specified in RFC 1952.
package gzip
import (
......
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// The lzw package implements the Lempel-Ziv-Welch compressed data format,
+// Package lzw implements the Lempel-Ziv-Welch compressed data format,
// described in T. A. Welch, ``A Technique for High-Performance Data
// Compression'', Computer, 17(6) (June 1984), pp 8-19.
//
@@ -165,16 +165,19 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
			if _, err := w.Write(buf[i:]); err != nil {
				return err
			}
-			// Save what the hi code expands to.
-			suffix[hi] = uint8(c)
-			prefix[hi] = last
+			if last != invalidCode {
+				// Save what the hi code expands to.
+				suffix[hi] = uint8(c)
+				prefix[hi] = last
+			}
		default:
			return os.NewError("lzw: invalid code")
		}
		last, hi = code, hi+1
-		if hi == overflow {
+		if hi >= overflow {
			if d.width == maxWidth {
-				return os.NewError("lzw: missing clear code")
+				last = invalidCode
+				continue
			}
			d.width++
			overflow <<= 1
......
...@@ -112,12 +112,6 @@ func TestReader(t *testing.T) { ...@@ -112,12 +112,6 @@ func TestReader(t *testing.T) {
} }
} }
type devNull struct{}
func (devNull) Write(p []byte) (int, os.Error) {
return len(p), nil
}
func benchmarkDecoder(b *testing.B, n int) { func benchmarkDecoder(b *testing.B, n int) {
b.StopTimer() b.StopTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
...@@ -134,7 +128,7 @@ func benchmarkDecoder(b *testing.B, n int) { ...@@ -134,7 +128,7 @@ func benchmarkDecoder(b *testing.B, n int) {
runtime.GC() runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8)) io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1), LSB, 8))
} }
} }
......
...@@ -113,7 +113,7 @@ func benchmarkEncoder(b *testing.B, n int) { ...@@ -113,7 +113,7 @@ func benchmarkEncoder(b *testing.B, n int) {
runtime.GC() runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
w := NewWriter(devNull{}, LSB, 8) w := NewWriter(ioutil.Discard, LSB, 8)
w.Write(buf1) w.Write(buf1)
w.Close() w.Close()
} }
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/* /*
The zlib package implements reading and writing of zlib Package zlib implements reading and writing of zlib format compressed data,
format compressed data, as specified in RFC 1950. as specified in RFC 1950.
The implementation provides filters that uncompress during reading The implementation provides filters that uncompress during reading
and compress during writing. For example, to write compressed data and compress during writing. For example, to write compressed data
...@@ -36,7 +36,7 @@ const zlibDeflate = 8 ...@@ -36,7 +36,7 @@ const zlibDeflate = 8
var ChecksumError os.Error = os.ErrorString("zlib checksum error") var ChecksumError os.Error = os.ErrorString("zlib checksum error")
var HeaderError os.Error = os.ErrorString("invalid zlib header") var HeaderError os.Error = os.ErrorString("invalid zlib header")
var UnsupportedError os.Error = os.ErrorString("unsupported zlib format") var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary")
type reader struct { type reader struct {
r flate.Reader r flate.Reader
...@@ -50,6 +50,12 @@ type reader struct { ...@@ -50,6 +50,12 @@ type reader struct {
// The implementation buffers input and may read more data than necessary from r. // The implementation buffers input and may read more data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when done. // It is the caller's responsibility to call Close on the ReadCloser when done.
func NewReader(r io.Reader) (io.ReadCloser, os.Error) { func NewReader(r io.Reader) (io.ReadCloser, os.Error) {
return NewReaderDict(r, nil)
}
// NewReaderDict is like NewReader but uses a preset dictionary.
// NewReaderDict ignores the dictionary if the compressed data does not refer to it.
func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, os.Error) {
z := new(reader) z := new(reader)
if fr, ok := r.(flate.Reader); ok { if fr, ok := r.(flate.Reader); ok {
z.r = fr z.r = fr
...@@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) { ...@@ -65,11 +71,19 @@ func NewReader(r io.Reader) (io.ReadCloser, os.Error) {
return nil, HeaderError return nil, HeaderError
} }
if z.scratch[1]&0x20 != 0 { if z.scratch[1]&0x20 != 0 {
// BUG(nigeltao): The zlib package does not implement the FDICT flag. _, err = io.ReadFull(z.r, z.scratch[0:4])
return nil, UnsupportedError if err != nil {
return nil, err
}
checksum := uint32(z.scratch[0])<<24 | uint32(z.scratch[1])<<16 | uint32(z.scratch[2])<<8 | uint32(z.scratch[3])
if checksum != adler32.Checksum(dict) {
return nil, DictionaryError
}
z.decompressor = flate.NewReaderDict(z.r, dict)
} else {
z.decompressor = flate.NewReader(z.r)
} }
z.digest = adler32.New() z.digest = adler32.New()
z.decompressor = flate.NewReader(z.r)
return z, nil return z, nil
} }
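
Editor's note: the FDICT branch above validates the dictionary by comparing the four bytes that follow the header, read big-endian, against adler32.Checksum of the caller-supplied dictionary; a mismatch is reported as DictionaryError. A small standalone sketch of that packing and comparison (the dictionary value is arbitrary):

package main

import (
	"fmt"
	"hash/adler32"
)

func main() {
	dict := []byte("0123456789.")
	sum := adler32.Checksum(dict)

	// The writer stores the checksum big-endian after the two header bytes.
	field := []byte{uint8(sum >> 24), uint8(sum >> 16), uint8(sum >> 8), uint8(sum)}

	// The reader reassembles it the same way and compares.
	got := uint32(field[0])<<24 | uint32(field[1])<<16 | uint32(field[2])<<8 | uint32(field[3])
	fmt.Println(got == sum) // false here would mean DictionaryError
}
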
......
...@@ -15,6 +15,7 @@ type zlibTest struct { ...@@ -15,6 +15,7 @@ type zlibTest struct {
desc string desc string
raw string raw string
compressed []byte compressed []byte
dict []byte
err os.Error err os.Error
} }
...@@ -27,6 +28,7 @@ var zlibTests = []zlibTest{ ...@@ -27,6 +28,7 @@ var zlibTests = []zlibTest{
"", "",
[]byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01},
nil, nil,
nil,
}, },
{ {
"goodbye", "goodbye",
...@@ -37,23 +39,27 @@ var zlibTests = []zlibTest{ ...@@ -37,23 +39,27 @@ var zlibTests = []zlibTest{
0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e, 0x01, 0x00, 0x28, 0xa5, 0x05, 0x5e,
}, },
nil, nil,
nil,
}, },
{ {
"bad header", "bad header",
"", "",
[]byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}, []byte{0x78, 0x9f, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01},
nil,
HeaderError, HeaderError,
}, },
{ {
"bad checksum", "bad checksum",
"", "",
[]byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff}, []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0xff},
nil,
ChecksumError, ChecksumError,
}, },
{ {
"not enough data", "not enough data",
"", "",
[]byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00}, []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00},
nil,
io.ErrUnexpectedEOF, io.ErrUnexpectedEOF,
}, },
{ {
...@@ -64,6 +70,33 @@ var zlibTests = []zlibTest{ ...@@ -64,6 +70,33 @@ var zlibTests = []zlibTest{
0x78, 0x9c, 0xff, 0x78, 0x9c, 0xff,
}, },
nil, nil,
nil,
},
{
"dictionary",
"Hello, World!\n",
[]byte{
0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00,
0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24,
0x12, 0x04, 0x74,
},
[]byte{
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x0a,
},
nil,
},
{
"wrong dictionary",
"",
[]byte{
0x78, 0xbb, 0x1c, 0x32, 0x04, 0x27, 0xf3, 0x00,
0xb1, 0x75, 0x20, 0x1c, 0x45, 0x2e, 0x00, 0x24,
0x12, 0x04, 0x74,
},
[]byte{
0x48, 0x65, 0x6c, 0x6c,
},
DictionaryError,
}, },
} }
...@@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) { ...@@ -71,7 +104,7 @@ func TestDecompressor(t *testing.T) {
b := new(bytes.Buffer) b := new(bytes.Buffer)
for _, tt := range zlibTests { for _, tt := range zlibTests {
in := bytes.NewBuffer(tt.compressed) in := bytes.NewBuffer(tt.compressed)
zlib, err := NewReader(in) zlib, err := NewReaderDict(in, tt.dict)
if err != nil { if err != nil {
if err != tt.err { if err != tt.err {
t.Errorf("%s: NewReader: %s", tt.desc, err) t.Errorf("%s: NewReader: %s", tt.desc, err)
......
...@@ -21,56 +21,80 @@ const ( ...@@ -21,56 +21,80 @@ const (
DefaultCompression = flate.DefaultCompression DefaultCompression = flate.DefaultCompression
) )
type writer struct { // A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
w io.Writer w io.Writer
compressor io.WriteCloser compressor *flate.Writer
digest hash.Hash32 digest hash.Hash32
err os.Error err os.Error
scratch [4]byte scratch [4]byte
} }
// NewWriter calls NewWriterLevel with the default compression level. // NewWriter calls NewWriterLevel with the default compression level.
func NewWriter(w io.Writer) (io.WriteCloser, os.Error) { func NewWriter(w io.Writer) (*Writer, os.Error) {
return NewWriterLevel(w, DefaultCompression) return NewWriterLevel(w, DefaultCompression)
} }
// NewWriterLevel creates a new io.WriteCloser that satisfies writes by compressing data written to w. // NewWriterLevel calls NewWriterDict with no dictionary.
func NewWriterLevel(w io.Writer, level int) (*Writer, os.Error) {
return NewWriterDict(w, level, nil)
}
// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w.
// It is the caller's responsibility to call Close on the WriteCloser when done. // It is the caller's responsibility to call Close on the WriteCloser when done.
// level is the compression level, which can be DefaultCompression, NoCompression, // level is the compression level, which can be DefaultCompression, NoCompression,
// or any integer value between BestSpeed and BestCompression (inclusive). // or any integer value between BestSpeed and BestCompression (inclusive).
func NewWriterLevel(w io.Writer, level int) (io.WriteCloser, os.Error) { // dict is the preset dictionary to compress with, or nil to use no dictionary.
z := new(writer) func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
z := new(Writer)
// ZLIB has a two-byte header (as documented in RFC 1950). // ZLIB has a two-byte header (as documented in RFC 1950).
// The first four bits is the CINFO (compression info), which is 7 for the default deflate window size. // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size.
// The next four bits is the CM (compression method), which is 8 for deflate. // The next four bits is the CM (compression method), which is 8 for deflate.
z.scratch[0] = 0x78 z.scratch[0] = 0x78
// The next two bits is the FLEVEL (compression level). The four values are: // The next two bits is the FLEVEL (compression level). The four values are:
// 0=fastest, 1=fast, 2=default, 3=best. // 0=fastest, 1=fast, 2=default, 3=best.
// The next bit, FDICT, is unused, in this implementation. // The next bit, FDICT, is set if a dictionary is given.
// The final five FCHECK bits form a mod-31 checksum. // The final five FCHECK bits form a mod-31 checksum.
switch level { switch level {
case 0, 1: case 0, 1:
z.scratch[1] = 0x01 z.scratch[1] = 0 << 6
case 2, 3, 4, 5: case 2, 3, 4, 5:
z.scratch[1] = 0x5e z.scratch[1] = 1 << 6
case 6, -1: case 6, -1:
z.scratch[1] = 0x9c z.scratch[1] = 2 << 6
case 7, 8, 9: case 7, 8, 9:
z.scratch[1] = 0xda z.scratch[1] = 3 << 6
default: default:
return nil, os.NewError("level out of range") return nil, os.NewError("level out of range")
} }
if dict != nil {
z.scratch[1] |= 1 << 5
}
z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31)
_, err := w.Write(z.scratch[0:2]) _, err := w.Write(z.scratch[0:2])
if err != nil { if err != nil {
return nil, err return nil, err
} }
if dict != nil {
// The next four bytes are the Adler-32 checksum of the dictionary.
checksum := adler32.Checksum(dict)
z.scratch[0] = uint8(checksum >> 24)
z.scratch[1] = uint8(checksum >> 16)
z.scratch[2] = uint8(checksum >> 8)
z.scratch[3] = uint8(checksum >> 0)
_, err = w.Write(z.scratch[0:4])
if err != nil {
return nil, err
}
}
z.w = w z.w = w
z.compressor = flate.NewWriter(w, level) z.compressor = flate.NewWriter(w, level)
z.digest = adler32.New() z.digest = adler32.New()
return z, nil return z, nil
} }
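
Editor's note: the header bytes built above can be checked by hand. For the default level with no dictionary, CMF is 0x78, FLEVEL contributes 2<<6, FDICT stays clear, and FCHECK pads the 16-bit big-endian pair out to a multiple of 31, which reproduces the familiar 0x78 0x9c prefix. A standalone sketch of that arithmetic (not part of the package):

package main

import "fmt"

func main() {
	cmf := uint8(0x78)   // CINFO=7 (32K window), CM=8 (deflate)
	flg := uint8(2 << 6) // FLEVEL=2 (default compression), FDICT clear
	flg += uint8(31 - (uint16(cmf)<<8+uint16(flg))%31)

	fmt.Printf("%#x %#x\n", cmf, flg)                 // 0x78 0x9c
	fmt.Println((uint16(cmf)<<8+uint16(flg))%31 == 0) // FCHECK makes the pair a multiple of 31
}
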
func (z *writer) Write(p []byte) (n int, err os.Error) { func (z *Writer) Write(p []byte) (n int, err os.Error) {
if z.err != nil { if z.err != nil {
return 0, z.err return 0, z.err
} }
...@@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) { ...@@ -86,8 +110,17 @@ func (z *writer) Write(p []byte) (n int, err os.Error) {
return return
} }
// Flush flushes the underlying compressor.
func (z *Writer) Flush() os.Error {
if z.err != nil {
return z.err
}
z.err = z.compressor.Flush()
return z.err
}
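
Editor's note: a hedged sketch of the exported Writer in use, again with the os.Error-era API of this tree: Flush pushes whatever the compressor has buffered downstream without ending the stream, while Close finishes the stream and appends the Adler-32 trailer. Error handling is kept minimal on purpose:

package main

import (
	"bytes"
	"compress/zlib"
	"fmt"
)

func main() {
	var buf bytes.Buffer
	w, err := zlib.NewWriterDict(&buf, zlib.DefaultCompression, []byte("0123456789."))
	if err != nil {
		fmt.Println("NewWriterDict:", err)
		return
	}
	w.Write([]byte("0123456789. hello"))
	w.Flush() // data written so far becomes decodable by the reader
	fmt.Println("after Flush:", buf.Len(), "bytes")
	w.Close() // ends the stream and appends the Adler-32 of the uncompressed data
	fmt.Println("after Close:", buf.Len(), "bytes")
}
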
// Calling Close does not close the wrapped io.Writer originally passed to NewWriter. // Calling Close does not close the wrapped io.Writer originally passed to NewWriter.
func (z *writer) Close() os.Error { func (z *Writer) Close() os.Error {
if z.err != nil { if z.err != nil {
return z.err return z.err
} }
......
...@@ -16,13 +16,19 @@ var filenames = []string{ ...@@ -16,13 +16,19 @@ var filenames = []string{
"../testdata/pi.txt", "../testdata/pi.txt",
} }
// Tests that compressing and then decompressing the given file at the given compression level // Tests that compressing and then decompressing the given file at the given compression level and dictionary
// yields equivalent bytes to the original file. // yields equivalent bytes to the original file.
func testFileLevel(t *testing.T, fn string, level int) { func testFileLevelDict(t *testing.T, fn string, level int, d string) {
// Read dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Read the file, as golden output. // Read the file, as golden output.
golden, err := os.Open(fn) golden, err := os.Open(fn)
if err != nil { if err != nil {
t.Errorf("%s (level=%d): %v", fn, level, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
defer golden.Close() defer golden.Close()
...@@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -30,7 +36,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
// Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
raw, err := os.Open(fn) raw, err := os.Open(fn)
if err != nil { if err != nil {
t.Errorf("%s (level=%d): %v", fn, level, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
...@@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -38,9 +44,9 @@ func testFileLevel(t *testing.T, fn string, level int) {
go func() { go func() {
defer raw.Close() defer raw.Close()
defer pipew.Close() defer pipew.Close()
zlibw, err := NewWriterLevel(pipew, level) zlibw, err := NewWriterDict(pipew, level, dict)
if err != nil { if err != nil {
t.Errorf("%s (level=%d): %v", fn, level, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
defer zlibw.Close() defer zlibw.Close()
...@@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -48,7 +54,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
for { for {
n, err0 := raw.Read(b[0:]) n, err0 := raw.Read(b[0:])
if err0 != nil && err0 != os.EOF { if err0 != nil && err0 != os.EOF {
t.Errorf("%s (level=%d): %v", fn, level, err0) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return return
} }
_, err1 := zlibw.Write(b[0:n]) _, err1 := zlibw.Write(b[0:n])
...@@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -57,7 +63,7 @@ func testFileLevel(t *testing.T, fn string, level int) {
return return
} }
if err1 != nil { if err1 != nil {
t.Errorf("%s (level=%d): %v", fn, level, err1) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return return
} }
if err0 == os.EOF { if err0 == os.EOF {
...@@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -65,9 +71,9 @@ func testFileLevel(t *testing.T, fn string, level int) {
} }
} }
}() }()
zlibr, err := NewReader(piper) zlibr, err := NewReaderDict(piper, dict)
if err != nil { if err != nil {
t.Errorf("%s (level=%d): %v", fn, level, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
defer zlibr.Close() defer zlibr.Close()
...@@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -76,20 +82,20 @@ func testFileLevel(t *testing.T, fn string, level int) {
b0, err0 := ioutil.ReadAll(golden) b0, err0 := ioutil.ReadAll(golden)
b1, err1 := ioutil.ReadAll(zlibr) b1, err1 := ioutil.ReadAll(zlibr)
if err0 != nil { if err0 != nil {
t.Errorf("%s (level=%d): %v", fn, level, err0) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return return
} }
if err1 != nil { if err1 != nil {
t.Errorf("%s (level=%d): %v", fn, level, err1) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return return
} }
if len(b0) != len(b1) { if len(b0) != len(b1) {
t.Errorf("%s (level=%d): length mismatch %d versus %d", fn, level, len(b0), len(b1)) t.Errorf("%s (level=%d, dict=%q): length mismatch %d versus %d", fn, level, d, len(b0), len(b1))
return return
} }
for i := 0; i < len(b0); i++ { for i := 0; i < len(b0); i++ {
if b0[i] != b1[i] { if b0[i] != b1[i] {
t.Errorf("%s (level=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, i, b0[i], b1[i]) t.Errorf("%s (level=%d, dict=%q): mismatch at %d, 0x%02x versus 0x%02x\n", fn, level, d, i, b0[i], b1[i])
return return
} }
} }
...@@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) { ...@@ -97,10 +103,21 @@ func testFileLevel(t *testing.T, fn string, level int) {
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
for _, fn := range filenames { for _, fn := range filenames {
testFileLevel(t, fn, DefaultCompression) testFileLevelDict(t, fn, DefaultCompression, "")
testFileLevel(t, fn, NoCompression) testFileLevelDict(t, fn, NoCompression, "")
for level := BestSpeed; level <= BestCompression; level++ {
testFileLevelDict(t, fn, level, "")
}
}
}
func TestWriterDict(t *testing.T) {
const dictionary = "0123456789."
for _, fn := range filenames {
testFileLevelDict(t, fn, DefaultCompression, dictionary)
testFileLevelDict(t, fn, NoCompression, dictionary)
for level := BestSpeed; level <= BestCompression; level++ { for level := BestSpeed; level <= BestCompression; level++ {
testFileLevel(t, fn, level) testFileLevelDict(t, fn, level, dictionary)
} }
} }
} }
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package provides heap operations for any type that implements // Package heap provides heap operations for any type that implements
// heap.Interface. // heap.Interface.
// //
package heap package heap
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package heap package heap_test
import ( import (
"testing" "testing"
"container/vector" "container/vector"
. "container/heap"
) )
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The list package implements a doubly linked list. // Package list implements a doubly linked list.
// //
// To iterate over a list (where l is a *List): // To iterate over a list (where l is a *List):
// for e := l.Front(); e != nil; e = e.Next() { // for e := l.Front(); e != nil; e = e.Next() {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The ring package implements operations on circular lists. // Package ring implements operations on circular lists.
package ring package ring
// A Ring is an element of a circular list, or ring. // A Ring is an element of a circular list, or ring.
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The vector package implements containers for managing sequences // Package vector implements containers for managing sequences of elements.
// of elements. Vectors grow and shrink dynamically as necessary. // Vectors grow and shrink dynamically as necessary.
package vector package vector
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// AES constants - 8720 bytes of initialized data. // Package aes implements AES encryption (formerly Rijndael), as defined in
// U.S. Federal Information Processing Standards Publication 197.
// This package implements AES encryption (formerly Rijndael),
// as defined in U.S. Federal Information Processing Standards Publication 197.
package aes package aes
// This file contains AES constants - 8720 bytes of initialized data.
// http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf // http://www.csrc.nist.gov/publications/fips/fips197/fips-197.pdf
// AES is based on the mathematical behavior of binary polynomials // AES is based on the mathematical behavior of binary polynomials
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements Bruce Schneier's Blowfish encryption algorithm. // Package blowfish implements Bruce Schneier's Blowfish encryption algorithm.
package blowfish package blowfish
// The code is a port of Bruce Schneier's C implementation. // The code is a port of Bruce Schneier's C implementation.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements CAST5, as defined in RFC 2144. CAST5 is a common // Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common
// OpenPGP cipher. // OpenPGP cipher.
package cast5 package cast5
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The cipher package implements standard block cipher modes // Package cipher implements standard block cipher modes that can be wrapped
// that can be wrapped around low-level block cipher implementations. // around low-level block cipher implementations.
// See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html // See http://csrc.nist.gov/groups/ST/toolkit/BCM/current_modes.html
// and NIST Special Publication 800-38A. // and NIST Special Publication 800-38A.
package cipher package cipher
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The crypto package collects common cryptographic constants. // Package crypto collects common cryptographic constants.
package crypto package crypto
import ( import (
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The elliptic package implements several standard elliptic curves over prime // Package elliptic implements several standard elliptic curves over prime
// fields // fields.
package elliptic package elliptic
// This package operates, internally, on Jacobian coordinates. For a given // This package operates, internally, on Jacobian coordinates. For a given
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The hmac package implements the Keyed-Hash Message Authentication Code (HMAC) // Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as
// as defined in U.S. Federal Information Processing Standards Publication 198. // defined in U.S. Federal Information Processing Standards Publication 198.
// An HMAC is a cryptographic hash that uses a key to sign a message. // An HMAC is a cryptographic hash that uses a key to sign a message.
// The receiver verifies the hash by recomputing it using the same key. // The receiver verifies the hash by recomputing it using the same key.
package hmac package hmac
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the MD4 hash algorithm as defined in RFC 1320. // Package md4 implements the MD4 hash algorithm as defined in RFC 1320.
package md4 package md4
import ( import (
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the MD5 hash algorithm as defined in RFC 1321. // Package md5 implements the MD5 hash algorithm as defined in RFC 1321.
package md5 package md5
import ( import (
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package parses OCSP responses as specified in RFC 2560. OCSP responses // Package ocsp parses OCSP responses as specified in RFC 2560. OCSP responses
// are signed messages attesting to the validity of a certificate for a small // are signed messages attesting to the validity of a certificate for a small
// period of time. This is used to manage revocation for X.509 certificates. // period of time. This is used to manage revocation for X.509 certificates.
package ocsp package ocsp
......
...@@ -2,15 +2,15 @@ ...@@ -2,15 +2,15 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is // Package armor implements OpenPGP ASCII Armor, see RFC 4880. OpenPGP Armor is
// very similar to PEM except that it has an additional CRC checksum. // very similar to PEM except that it has an additional CRC checksum.
package armor package armor
import ( import (
"bufio"
"bytes" "bytes"
"crypto/openpgp/error" "crypto/openpgp/error"
"encoding/base64" "encoding/base64"
"encoding/line"
"io" "io"
"os" "os"
) )
...@@ -63,7 +63,7 @@ var armorEndOfLine = []byte("-----") ...@@ -63,7 +63,7 @@ var armorEndOfLine = []byte("-----")
// lineReader wraps a line based reader. It watches for the end of an armor // lineReader wraps a line based reader. It watches for the end of an armor
// block and records the expected CRC value. // block and records the expected CRC value.
type lineReader struct { type lineReader struct {
in *line.Reader in *bufio.Reader
buf []byte buf []byte
eof bool eof bool
crc uint32 crc uint32
...@@ -156,7 +156,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) { ...@@ -156,7 +156,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) {
// given Reader is not usable after calling this function: an arbitary amount // given Reader is not usable after calling this function: an arbitary amount
// of data may have been read past the end of the block. // of data may have been read past the end of the block.
func Decode(in io.Reader) (p *Block, err os.Error) { func Decode(in io.Reader) (p *Block, err os.Error) {
r := line.NewReader(in, 100) r, _ := bufio.NewReaderSize(in, 100)
var line []byte var line []byte
ignoreNext := false ignoreNext := false
......
...@@ -18,9 +18,9 @@ var armorEndOfLineOut = []byte("-----\n") ...@@ -18,9 +18,9 @@ var armorEndOfLineOut = []byte("-----\n")
// writeSlices writes its arguments to the given Writer. // writeSlices writes its arguments to the given Writer.
func writeSlices(out io.Writer, slices ...[]byte) (err os.Error) { func writeSlices(out io.Writer, slices ...[]byte) (err os.Error) {
for _, s := range slices { for _, s := range slices {
_, err := out.Write(s) _, err = out.Write(s)
if err != nil { if err != nil {
return return err
} }
} }
return return
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package contains common error types for the OpenPGP packages. // Package error contains common error types for the OpenPGP packages.
package error package error
import ( import (
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package openpgp package openpgp
import ( import (
"crypto/openpgp/armor"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/packet" "crypto/openpgp/packet"
"io" "io"
...@@ -13,6 +14,8 @@ import ( ...@@ -13,6 +14,8 @@ import (
// PublicKeyType is the armor type for a PGP public key. // PublicKeyType is the armor type for a PGP public key.
var PublicKeyType = "PGP PUBLIC KEY BLOCK" var PublicKeyType = "PGP PUBLIC KEY BLOCK"
// PrivateKeyType is the armor type for a PGP private key.
var PrivateKeyType = "PGP PRIVATE KEY BLOCK"
// An Entity represents the components of an OpenPGP key: a primary public key // An Entity represents the components of an OpenPGP key: a primary public key
// (which must be a signing key), one or more identities claimed by that key, // (which must be a signing key), one or more identities claimed by that key,
...@@ -101,37 +104,50 @@ func (el EntityList) DecryptionKeys() (keys []Key) { ...@@ -101,37 +104,50 @@ func (el EntityList) DecryptionKeys() (keys []Key) {
// ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file. // ReadArmoredKeyRing reads one or more public/private keys from an armor keyring file.
func ReadArmoredKeyRing(r io.Reader) (EntityList, os.Error) { func ReadArmoredKeyRing(r io.Reader) (EntityList, os.Error) {
body, err := readArmored(r, PublicKeyType) block, err := armor.Decode(r)
if err == os.EOF {
return nil, error.InvalidArgumentError("no armored data found")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if block.Type != PublicKeyType && block.Type != PrivateKeyType {
return nil, error.InvalidArgumentError("expected public or private key block, got: " + block.Type)
}
return ReadKeyRing(body) return ReadKeyRing(block.Body)
} }
// ReadKeyRing reads one or more public/private keys, ignoring unsupported keys. // ReadKeyRing reads one or more public/private keys. Unsupported keys are
// ignored as long as at least a single valid key is found.
func ReadKeyRing(r io.Reader) (el EntityList, err os.Error) { func ReadKeyRing(r io.Reader) (el EntityList, err os.Error) {
packets := packet.NewReader(r) packets := packet.NewReader(r)
var lastUnsupportedError os.Error
for { for {
var e *Entity var e *Entity
e, err = readEntity(packets) e, err = readEntity(packets)
if err != nil { if err != nil {
if _, ok := err.(error.UnsupportedError); ok { if _, ok := err.(error.UnsupportedError); ok {
lastUnsupportedError = err
err = readToNextPublicKey(packets) err = readToNextPublicKey(packets)
} }
if err == os.EOF { if err == os.EOF {
err = nil err = nil
return break
} }
if err != nil { if err != nil {
el = nil el = nil
return break
} }
} else { } else {
el = append(el, e) el = append(el, e)
} }
} }
if len(el) == 0 && err == nil {
err = lastUnsupportedError
}
return return
} }
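
Editor's note: a hedged sketch of the armored entry point as it now behaves; it mirrors the new TestReadingArmoredPrivateKey further down in this patch. keyArmor is a placeholder for a real "PGP PUBLIC KEY BLOCK" or "PGP PRIVATE KEY BLOCK"; any other armor type, or no armored data at all, is rejected with an InvalidArgumentError:

package main

import (
	"bytes"
	"crypto/openpgp"
	"fmt"
)

// keyArmor stands in for a real ASCII-armored keyring; with this placeholder
// the call below simply reports an error.
const keyArmor = "-----BEGIN PGP PRIVATE KEY BLOCK-----\n...\n-----END PGP PRIVATE KEY BLOCK-----"

func main() {
	el, err := openpgp.ReadArmoredKeyRing(bytes.NewBufferString(keyArmor))
	if err != nil {
		fmt.Println("ReadArmoredKeyRing:", err)
		return
	}
	fmt.Println("entities read:", len(el))
}
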
...@@ -197,25 +213,28 @@ EachPacket: ...@@ -197,25 +213,28 @@ EachPacket:
current.Name = pkt.Id current.Name = pkt.Id
current.UserId = pkt current.UserId = pkt
e.Identities[pkt.Id] = current e.Identities[pkt.Id] = current
p, err = packets.Next()
if err == os.EOF { for {
err = io.ErrUnexpectedEOF p, err = packets.Next()
} if err == os.EOF {
if err != nil { return nil, io.ErrUnexpectedEOF
if _, ok := err.(error.UnsupportedError); ok { } else if err != nil {
return nil, err return nil, err
} }
return nil, error.StructuralError("identity self-signature invalid: " + err.String())
} sig, ok := p.(*packet.Signature)
current.SelfSignature, ok = p.(*packet.Signature) if !ok {
if !ok { return nil, error.StructuralError("user ID packet not followed by self-signature")
return nil, error.StructuralError("user ID packet not followed by self signature") }
}
if current.SelfSignature.SigType != packet.SigTypePositiveCert { if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
return nil, error.StructuralError("user ID self-signature with wrong type") if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
} return nil, error.StructuralError("user ID self-signature invalid: " + err.String())
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, current.SelfSignature); err != nil { }
return nil, error.StructuralError("user ID self-signature invalid: " + err.String()) current.SelfSignature = sig
break
}
current.Signatures = append(current.Signatures, sig)
} }
case *packet.Signature: case *packet.Signature:
if current == nil { if current == nil {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements parsing and serialisation of OpenPGP packets, as // Package packet implements parsing and serialisation of OpenPGP packets, as
// specified in RFC 4880. // specified in RFC 4880.
package packet package packet
......
...@@ -164,8 +164,10 @@ func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) { ...@@ -164,8 +164,10 @@ func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) {
} }
rsaPriv.D = new(big.Int).SetBytes(d) rsaPriv.D = new(big.Int).SetBytes(d)
rsaPriv.P = new(big.Int).SetBytes(p) rsaPriv.Primes = make([]*big.Int, 2)
rsaPriv.Q = new(big.Int).SetBytes(q) rsaPriv.Primes[0] = new(big.Int).SetBytes(p)
rsaPriv.Primes[1] = new(big.Int).SetBytes(q)
rsaPriv.Precompute()
pk.PrivateKey = rsaPriv pk.PrivateKey = rsaPriv
pk.Encrypted = false pk.Encrypted = false
pk.encryptedData = nil pk.encryptedData = nil
......
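
Editor's note: the parser above now fills the Primes slice instead of the removed P and Q fields and calls the exported Precompute, matching the multi-prime rsa.PrivateKey layout exercised by the test changes near the end of this patch. A hedged sketch of constructing such a key by hand, with textbook-sized toy values (p=61, q=53, e=17, d=2753) that are obviously not secure; "big" is the pre-Go 1 import path used by this tree:

package main

import (
	"big"
	"crypto/rsa"
	"fmt"
)

func main() {
	priv := &rsa.PrivateKey{
		PublicKey: rsa.PublicKey{
			N: big.NewInt(61 * 53), // 3233
			E: 17,
		},
		D:      big.NewInt(2753),
		Primes: []*big.Int{big.NewInt(61), big.NewInt(53)},
	}
	priv.Precompute() // derives the CRT values from Primes, as parseRSAPrivateKey now does
	fmt.Println("primes:", len(priv.Primes))
}
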
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"hash" "hash"
"io" "io"
"os" "os"
"strconv"
) )
// PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2.
...@@ -47,7 +48,7 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) { ...@@ -47,7 +48,7 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
err = pk.parseDSA(r) err = pk.parseDSA(r)
default: default:
err = error.UnsupportedError("public key type") err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
} }
if err != nil { if err != nil {
return return
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This openpgp package implements high level operations on OpenPGP messages. // Package openpgp implements high level operations on OpenPGP messages.
package openpgp package openpgp
import ( import (
......
...@@ -230,6 +230,23 @@ func TestDetachedSignatureDSA(t *testing.T) { ...@@ -230,6 +230,23 @@ func TestDetachedSignatureDSA(t *testing.T) {
testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId) testDetachedSignature(t, kring, readerFromHex(detachedSignatureDSAHex), signedInput, "binary", testKey3KeyId)
} }
func TestReadingArmoredPrivateKey(t *testing.T) {
el, err := ReadArmoredKeyRing(bytes.NewBufferString(armoredPrivateKeyBlock))
if err != nil {
t.Error(err)
}
if len(el) != 1 {
t.Errorf("got %d entities, wanted 1\n", len(el))
}
}
func TestNoArmoredData(t *testing.T) {
_, err := ReadArmoredKeyRing(bytes.NewBufferString("foo"))
if _, ok := err.(error.InvalidArgumentError); !ok {
t.Errorf("error was not an InvalidArgumentError: %s", err)
}
}
const testKey1KeyId = 0xA34D7E18C20C31BB const testKey1KeyId = 0xA34D7E18C20C31BB
const testKey3KeyId = 0x338934250CCC0360 const testKey3KeyId = 0x338934250CCC0360
...@@ -259,3 +276,37 @@ const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f7 ...@@ -259,3 +276,37 @@ const symmetricallyEncryptedCompressedHex = "8c0d04030302eb4a03808145d0d260c92f7
const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" const dsaTestKeyHex = "9901a2044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794"
const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794" const dsaTestKeyPrivateHex = "9501bb044d6c49de110400cb5ce438cf9250907ac2ba5bf6547931270b89f7c4b53d9d09f4d0213a5ef2ec1f26806d3d259960f872a4a102ef1581ea3f6d6882d15134f21ef6a84de933cc34c47cc9106efe3bd84c6aec12e78523661e29bc1a61f0aab17fa58a627fd5fd33f5149153fbe8cd70edf3d963bc287ef875270ff14b5bfdd1bca4483793923b00a0fe46d76cb6e4cbdc568435cd5480af3266d610d303fe33ae8273f30a96d4d34f42fa28ce1112d425b2e3bf7ea553d526e2db6b9255e9dc7419045ce817214d1a0056dbc8d5289956a4b1b69f20f1105124096e6a438f41f2e2495923b0f34b70642607d45559595c7fe94d7fa85fc41bf7d68c1fd509ebeaa5f315f6059a446b9369c277597e4f474a9591535354c7e7f4fd98a08aa60400b130c24ff20bdfbf683313f5daebf1c9b34b3bdadfc77f2ddd72ee1fb17e56c473664bc21d66467655dd74b9005e3a2bacce446f1920cd7017231ae447b67036c9b431b8179deacd5120262d894c26bc015bffe3d827ba7087ad9b700d2ca1f6d16cc1786581e5dd065f293c31209300f9b0afcc3f7c08dd26d0a22d87580b4d00009f592e0619d823953577d4503061706843317e4fee083db41054657374204b65792033202844534129886204131102002205024d6c49de021b03060b090807030206150802090a0b0416020301021e01021780000a0910338934250ccc03607e0400a0bdb9193e8a6b96fc2dfc108ae848914b504481f100a09c4dc148cb693293a67af24dd40d2b13a9e36794"
const armoredPrivateKeyBlock = `-----BEGIN PGP PRIVATE KEY BLOCK-----
Version: GnuPG v1.4.10 (GNU/Linux)
lQHYBE2rFNoBBADFwqWQIW/DSqcB4yCQqnAFTJ27qS5AnB46ccAdw3u4Greeu3Bp
idpoHdjULy7zSKlwR1EA873dO/k/e11Ml3dlAFUinWeejWaK2ugFP6JjiieSsrKn
vWNicdCS4HTWn0X4sjl0ZiAygw6GNhqEQ3cpLeL0g8E9hnYzJKQ0LWJa0QARAQAB
AAP/TB81EIo2VYNmTq0pK1ZXwUpxCrvAAIG3hwKjEzHcbQznsjNvPUihZ+NZQ6+X
0HCfPAdPkGDCLCb6NavcSW+iNnLTrdDnSI6+3BbIONqWWdRDYJhqZCkqmG6zqSfL
IdkJgCw94taUg5BWP/AAeQrhzjChvpMQTVKQL5mnuZbUCeMCAN5qrYMP2S9iKdnk
VANIFj7656ARKt/nf4CBzxcpHTyB8+d2CtPDKCmlJP6vL8t58Jmih+kHJMvC0dzn
gr5f5+sCAOOe5gt9e0am7AvQWhdbHVfJU0TQJx+m2OiCJAqGTB1nvtBLHdJnfdC9
TnXXQ6ZXibqLyBies/xeY2sCKL5qtTMCAKnX9+9d/5yQxRyrQUHt1NYhaXZnJbHx
q4ytu0eWz+5i68IYUSK69jJ1NWPM0T6SkqpB3KCAIv68VFm9PxqG1KmhSrQIVGVz
dCBLZXmIuAQTAQIAIgUCTasU2gIbAwYLCQgHAwIGFQgCCQoLBBYCAwECHgECF4AA
CgkQO9o98PRieSoLhgQAkLEZex02Qt7vGhZzMwuN0R22w3VwyYyjBx+fM3JFETy1
ut4xcLJoJfIaF5ZS38UplgakHG0FQ+b49i8dMij0aZmDqGxrew1m4kBfjXw9B/v+
eIqpODryb6cOSwyQFH0lQkXC040pjq9YqDsO5w0WYNXYKDnzRV0p4H1pweo2VDid
AdgETasU2gEEAN46UPeWRqKHvA99arOxee38fBt2CI08iiWyI8T3J6ivtFGixSqV
bRcPxYO/qLpVe5l84Nb3X71GfVXlc9hyv7CD6tcowL59hg1E/DC5ydI8K8iEpUmK
/UnHdIY5h8/kqgGxkY/T/hgp5fRQgW1ZoZxLajVlMRZ8W4tFtT0DeA+JABEBAAEA
A/0bE1jaaZKj6ndqcw86jd+QtD1SF+Cf21CWRNeLKnUds4FRRvclzTyUMuWPkUeX
TaNNsUOFqBsf6QQ2oHUBBK4VCHffHCW4ZEX2cd6umz7mpHW6XzN4DECEzOVksXtc
lUC1j4UB91DC/RNQqwX1IV2QLSwssVotPMPqhOi0ZLNY7wIA3n7DWKInxYZZ4K+6
rQ+POsz6brEoRHwr8x6XlHenq1Oki855pSa1yXIARoTrSJkBtn5oI+f8AzrnN0BN
oyeQAwIA/7E++3HDi5aweWrViiul9cd3rcsS0dEnksPhvS0ozCJiHsq/6GFmy7J8
QSHZPteedBnZyNp5jR+H7cIfVN3KgwH/Skq4PsuPhDq5TKK6i8Pc1WW8MA6DXTdU
nLkX7RGmMwjC0DBf7KWAlPjFaONAX3a8ndnz//fy1q7u2l9AZwrj1qa1iJ8EGAEC
AAkFAk2rFNoCGwwACgkQO9o98PRieSo2/QP/WTzr4ioINVsvN1akKuekmEMI3LAp
BfHwatufxxP1U+3Si/6YIk7kuPB9Hs+pRqCXzbvPRrI8NHZBmc8qIGthishdCYad
AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL
VrM0m72/jnpKo04=
=zNCn
-----END PGP PRIVATE KEY BLOCK-----`
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the various OpenPGP string-to-key transforms as // Package s2k implements the various OpenPGP string-to-key transforms as
// specified in RFC 4880 section 3.7.1. // specified in RFC 4880 section 3.7.1.
package s2k package s2k
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements RC4 encryption, as defined in Bruce Schneier's // Package rc4 implements RC4 encryption, as defined in Bruce Schneier's
// Applied Cryptography. // Applied Cryptography.
package rc4 package rc4
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the RIPEMD-160 hash algorithm. // Package ripemd160 implements the RIPEMD-160 hash algorithm.
package ripemd160 package ripemd160
// RIPEMD-160 is designed by Hans Dobbertin, Antoon Bosselaers, and Bart // RIPEMD-160 is designed by Hans Dobbertin, Antoon Bosselaers, and Bart
......
...@@ -149,10 +149,10 @@ func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) { ...@@ -149,10 +149,10 @@ func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) {
// precompute a prefix of the digest value that makes a valid ASN1 DER string // precompute a prefix of the digest value that makes a valid ASN1 DER string
// with the correct contents. // with the correct contents.
var hashPrefixes = map[crypto.Hash][]byte{ var hashPrefixes = map[crypto.Hash][]byte{
crypto.MD5: []byte{0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10}, crypto.MD5: {0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10},
crypto.SHA1: []byte{0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, crypto.SHA1: {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14},
crypto.SHA256: []byte{0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, crypto.SHA256: {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20},
crypto.SHA384: []byte{0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, crypto.SHA384: {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30},
crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40},
crypto.MD5SHA1: {}, // A special TLS case which doesn't use an ASN1 prefix. crypto.MD5SHA1: {}, // A special TLS case which doesn't use an ASN1 prefix.
crypto.RIPEMD160: {0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14}, crypto.RIPEMD160: {0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14},
......
...@@ -197,12 +197,6 @@ func TestVerifyPKCS1v15(t *testing.T) { ...@@ -197,12 +197,6 @@ func TestVerifyPKCS1v15(t *testing.T) {
} }
} }
func bigFromString(s string) *big.Int {
ret := new(big.Int)
ret.SetString(s, 10)
return ret
}
// In order to generate new test vectors you'll need the PEM form of this key: // In order to generate new test vectors you'll need the PEM form of this key:
// -----BEGIN RSA PRIVATE KEY----- // -----BEGIN RSA PRIVATE KEY-----
// MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0
...@@ -216,10 +210,12 @@ func bigFromString(s string) *big.Int { ...@@ -216,10 +210,12 @@ func bigFromString(s string) *big.Int {
var rsaPrivateKey = &PrivateKey{ var rsaPrivateKey = &PrivateKey{
PublicKey: PublicKey{ PublicKey: PublicKey{
N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), N: fromBase10("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"),
E: 65537, E: 65537,
}, },
D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), D: fromBase10("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"),
P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), Primes: []*big.Int{
Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), fromBase10("98920366548084643601728869055592650835572950932266967461790948584315647051443"),
fromBase10("94560208308847015747498523884063394671606671904944666360068158221458669711639"),
},
} }
...@@ -30,7 +30,20 @@ func Test3PrimeKeyGeneration(t *testing.T) { ...@@ -30,7 +30,20 @@ func Test3PrimeKeyGeneration(t *testing.T) {
} }
size := 768 size := 768
priv, err := Generate3PrimeKey(rand.Reader, size) priv, err := GenerateMultiPrimeKey(rand.Reader, 3, size)
if err != nil {
t.Errorf("failed to generate key")
}
testKeyBasics(t, priv)
}
func Test4PrimeKeyGeneration(t *testing.T) {
if testing.Short() {
return
}
size := 768
priv, err := GenerateMultiPrimeKey(rand.Reader, 4, size)
if err != nil { if err != nil {
t.Errorf("failed to generate key") t.Errorf("failed to generate key")
} }
...@@ -45,6 +58,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { ...@@ -45,6 +58,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) {
pub := &priv.PublicKey pub := &priv.PublicKey
m := big.NewInt(42) m := big.NewInt(42)
c := encrypt(new(big.Int), pub, m) c := encrypt(new(big.Int), pub, m)
m2, err := decrypt(nil, priv, c) m2, err := decrypt(nil, priv, c)
if err != nil { if err != nil {
t.Errorf("error while decrypting: %s", err) t.Errorf("error while decrypting: %s", err)
...@@ -59,7 +73,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { ...@@ -59,7 +73,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) {
t.Errorf("error while decrypting (blind): %s", err) t.Errorf("error while decrypting (blind): %s", err)
} }
if m.Cmp(m3) != 0 { if m.Cmp(m3) != 0 {
t.Errorf("(blind) got:%v, want:%v", m3, m) t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv)
} }
} }
...@@ -77,10 +91,12 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { ...@@ -77,10 +91,12 @@ func BenchmarkRSA2048Decrypt(b *testing.B) {
E: 3, E: 3,
}, },
D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"), D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"),
P: fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"), Primes: []*big.Int{
Q: fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"), fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"),
fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"),
},
} }
priv.precompute() priv.Precompute()
c := fromBase10("1000") c := fromBase10("1000")
...@@ -99,11 +115,13 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { ...@@ -99,11 +115,13 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) {
E: 3, E: 3,
}, },
D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"), D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"),
P: fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"), Primes: []*big.Int{
Q: fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"), fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"),
R: fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"), fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"),
fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"),
},
} }
priv.precompute() priv.Precompute()
c := fromBase10("1000") c := fromBase10("1000")
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the SHA1 hash algorithm as defined in RFC 3174. // Package sha1 implements the SHA1 hash algorithm as defined in RFC 3174.
package sha1 package sha1
import ( import (
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the SHA224 and SHA256 hash algorithms as defined in FIPS 180-2. // Package sha256 implements the SHA224 and SHA256 hash algorithms as defined
// in FIPS 180-2.
package sha256 package sha256
import ( import (
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the SHA384 and SHA512 hash algorithms as defined in FIPS 180-2. // Package sha512 implements the SHA384 and SHA512 hash algorithms as defined
// in FIPS 180-2.
package sha512 package sha512
import ( import (
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements functions that are often useful in cryptographic // Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly. // code but require careful thought to use correctly.
package subtle package subtle
......
...@@ -100,6 +100,8 @@ type ConnectionState struct { ...@@ -100,6 +100,8 @@ type ConnectionState struct {
// the certificate chain that was presented by the other side // the certificate chain that was presented by the other side
PeerCertificates []*x509.Certificate PeerCertificates []*x509.Certificate
// the verified certificate chains built from PeerCertificates.
VerifiedChains [][]*x509.Certificate
} }
// A Config structure is used to configure a TLS client or server. After one // A Config structure is used to configure a TLS client or server. After one
...@@ -122,7 +124,7 @@ type Config struct { ...@@ -122,7 +124,7 @@ type Config struct {
// RootCAs defines the set of root certificate authorities // RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates. // that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set. // If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *CASet RootCAs *x509.CertPool
// NextProtos is a list of supported, application level protocols. // NextProtos is a list of supported, application level protocols.
NextProtos []string NextProtos []string
...@@ -158,7 +160,7 @@ func (c *Config) time() int64 { ...@@ -158,7 +160,7 @@ func (c *Config) time() int64 {
return t() return t()
} }
func (c *Config) rootCAs() *CASet { func (c *Config) rootCAs() *x509.CertPool {
s := c.RootCAs s := c.RootCAs
if s == nil { if s == nil {
s = defaultRoots() s = defaultRoots()
...@@ -178,6 +180,9 @@ func (c *Config) cipherSuites() []uint16 { ...@@ -178,6 +180,9 @@ func (c *Config) cipherSuites() []uint16 {
type Certificate struct { type Certificate struct {
Certificate [][]byte Certificate [][]byte
PrivateKey *rsa.PrivateKey PrivateKey *rsa.PrivateKey
// OCSPStaple contains an optional OCSP response which will be served
// to clients that request it.
OCSPStaple []byte
} }
// A TLS record. // A TLS record.
...@@ -221,7 +226,7 @@ var certFiles = []string{ ...@@ -221,7 +226,7 @@ var certFiles = []string{
var once sync.Once var once sync.Once
func defaultRoots() *CASet { func defaultRoots() *x509.CertPool {
once.Do(initDefaults) once.Do(initDefaults)
return varDefaultRoots return varDefaultRoots
} }
...@@ -236,14 +241,14 @@ func initDefaults() { ...@@ -236,14 +241,14 @@ func initDefaults() {
initDefaultCipherSuites() initDefaultCipherSuites()
} }
var varDefaultRoots *CASet var varDefaultRoots *x509.CertPool
func initDefaultRoots() { func initDefaultRoots() {
roots := NewCASet() roots := x509.NewCertPool()
for _, file := range certFiles { for _, file := range certFiles {
data, err := ioutil.ReadFile(file) data, err := ioutil.ReadFile(file)
if err == nil { if err == nil {
roots.SetFromPEM(data) roots.AppendCertsFromPEM(data)
break break
} }
} }
...@@ -255,7 +260,7 @@ var varDefaultCipherSuites []uint16 ...@@ -255,7 +260,7 @@ var varDefaultCipherSuites []uint16
func initDefaultCipherSuites() { func initDefaultCipherSuites() {
varDefaultCipherSuites = make([]uint16, len(cipherSuites)) varDefaultCipherSuites = make([]uint16, len(cipherSuites))
i := 0 i := 0
for id, _ := range cipherSuites { for id := range cipherSuites {
varDefaultCipherSuites[i] = id varDefaultCipherSuites[i] = id
i++ i++
} }
......
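With RootCAs now an *x509.CertPool, callers populate it with AppendCertsFromPEM instead of the old CASet.SetFromPEM. A hedged sketch of building a Config against this patch; the PEM path is only illustrative, mirroring the certFiles list above:

package tlsconfig

import (
	"crypto/tls"
	"crypto/x509"
	"io/ioutil"
	"os"
)

// configWithRoots loads a PEM bundle into the new CertPool type and uses it
// as the verification root set for a TLS client.
func configWithRoots(pemFile string) (*tls.Config, os.Error) {
	data, err := ioutil.ReadFile(pemFile)
	if err != nil {
		return nil, err
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(data) {
		return nil, os.ErrorString("no root certificates parsed")
	}
	return &tls.Config{RootCAs: roots}, nil
}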
...@@ -34,6 +34,9 @@ type Conn struct { ...@@ -34,6 +34,9 @@ type Conn struct {
cipherSuite uint16 cipherSuite uint16
ocspResponse []byte // stapled OCSP response ocspResponse []byte // stapled OCSP response
peerCertificates []*x509.Certificate peerCertificates []*x509.Certificate
// verifiedChains contains the certificate chains that we built, as
// opposed to the ones presented by the server.
verifiedChains [][]*x509.Certificate
clientProtocol string clientProtocol string
clientProtocolFallback bool clientProtocolFallback bool
...@@ -765,6 +768,7 @@ func (c *Conn) ConnectionState() ConnectionState { ...@@ -765,6 +768,7 @@ func (c *Conn) ConnectionState() ConnectionState {
state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
state.CipherSuite = c.cipherSuite state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates state.PeerCertificates = c.peerCertificates
state.VerifiedChains = c.verifiedChains
} }
return state return state
......
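The new VerifiedChains field makes the chains built during verification available to callers of ConnectionState. A small sketch, assuming a *tls.Conn whose handshake has already completed:

package tlsinspect

import (
	"crypto/tls"
	"fmt"
)

// printChains reports the verified chains recorded during the handshake.
// Each chain starts with the peer's certificate and ends with a configured root.
func printChains(conn *tls.Conn) {
	state := conn.ConnectionState()
	for i, chain := range state.VerifiedChains {
		fmt.Printf("chain %d: %d certificates\n", i, len(chain))
	}
}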
...@@ -88,7 +88,6 @@ func (c *Conn) clientHandshake() os.Error { ...@@ -88,7 +88,6 @@ func (c *Conn) clientHandshake() os.Error {
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
certs := make([]*x509.Certificate, len(certMsg.certificates)) certs := make([]*x509.Certificate, len(certMsg.certificates))
chain := NewCASet()
for i, asn1Data := range certMsg.certificates { for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data) cert, err := x509.ParseCertificate(asn1Data)
if err != nil { if err != nil {
...@@ -96,47 +95,29 @@ func (c *Conn) clientHandshake() os.Error { ...@@ -96,47 +95,29 @@ func (c *Conn) clientHandshake() os.Error {
return os.ErrorString("failed to parse certificate from server: " + err.String()) return os.ErrorString("failed to parse certificate from server: " + err.String())
} }
certs[i] = cert certs[i] = cert
chain.AddCert(cert)
} }
// If we don't have a root CA set configured then anything is accepted. // If we don't have a root CA set configured then anything is accepted.
// TODO(rsc): Find certificates for OS X 10.6. // TODO(rsc): Find certificates for OS X 10.6.
for cur := certs[0]; c.config.RootCAs != nil; { if c.config.RootCAs != nil {
parent := c.config.RootCAs.FindVerifiedParent(cur) opts := x509.VerifyOptions{
if parent != nil { Roots: c.config.RootCAs,
break CurrentTime: c.config.time(),
DNSName: c.config.ServerName,
Intermediates: x509.NewCertPool(),
} }
parent = chain.FindVerifiedParent(cur) for i, cert := range certs {
if parent == nil { if i == 0 {
c.sendAlert(alertBadCertificate) continue
return os.ErrorString("could not find root certificate for chain") }
opts.Intermediates.AddCert(cert)
} }
c.verifiedChains, err = certs[0].Verify(opts)
if !parent.BasicConstraintsValid || !parent.IsCA { if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return os.ErrorString("intermediate certificate does not have CA bit set") return err
} }
// KeyUsage status flags are ignored. From Engineering
// Security, Peter Gutmann: A European government CA marked its
// signing certificates as being valid for encryption only, but
// no-one noticed. Another European CA marked its signature
// keys as not being valid for signatures. A different CA
// marked its own trusted root certificate as being invalid for
// certificate signing. Another national CA distributed a
// certificate to be used to encrypt data for the country’s tax
// authority that was marked as only being usable for digital
// signatures but not for encryption. Yet another CA reversed
// the order of the bit flags in the keyUsage due to confusion
// over encoding endianness, essentially setting a random
// keyUsage in certificates that it issued. Another CA created
// a self-invalidating certificate by adding a certificate
// policy statement stipulating that the certificate had to be
// used strictly as specified in the keyUsage, and a keyUsage
// containing a flag indicating that the RSA encryption key
// could only be used for Diffie-Hellman key agreement.
cur = parent
} }
if _, ok := certs[0].PublicKey.(*rsa.PublicKey); !ok { if _, ok := certs[0].PublicKey.(*rsa.PublicKey); !ok {
...@@ -145,7 +126,7 @@ func (c *Conn) clientHandshake() os.Error { ...@@ -145,7 +126,7 @@ func (c *Conn) clientHandshake() os.Error {
c.peerCertificates = certs c.peerCertificates = certs
if serverHello.certStatus { if serverHello.ocspStapling {
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
return err return err
......
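The hand-rolled parent walk is replaced by building a VerifyOptions value and letting x509.Certificate.Verify construct the chains. The same pattern, restated as a stand-alone helper against the x509 API introduced below (certs[0] is the leaf, the rest are treated as intermediates):

package tlsverify

import (
	"crypto/x509"
	"os"
	"time"
)

// verifyPeer mirrors the handshake logic above; certs must contain at least
// the leaf certificate.
func verifyPeer(certs []*x509.Certificate, roots *x509.CertPool, host string) ([][]*x509.Certificate, os.Error) {
	opts := x509.VerifyOptions{
		Roots:         roots,
		CurrentTime:   time.Seconds(),
		DNSName:       host,
		Intermediates: x509.NewCertPool(),
	}
	for _, cert := range certs[1:] {
		opts.Intermediates.AddCert(cert)
	}
	return certs[0].Verify(opts)
}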
...@@ -306,7 +306,7 @@ type serverHelloMsg struct { ...@@ -306,7 +306,7 @@ type serverHelloMsg struct {
compressionMethod uint8 compressionMethod uint8
nextProtoNeg bool nextProtoNeg bool
nextProtos []string nextProtos []string
certStatus bool ocspStapling bool
} }
func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
...@@ -327,7 +327,7 @@ func (m *serverHelloMsg) marshal() []byte { ...@@ -327,7 +327,7 @@ func (m *serverHelloMsg) marshal() []byte {
nextProtoLen += len(m.nextProtos) nextProtoLen += len(m.nextProtos)
extensionsLength += nextProtoLen extensionsLength += nextProtoLen
} }
if m.certStatus { if m.ocspStapling {
numExtensions++ numExtensions++
} }
if numExtensions > 0 { if numExtensions > 0 {
...@@ -373,7 +373,7 @@ func (m *serverHelloMsg) marshal() []byte { ...@@ -373,7 +373,7 @@ func (m *serverHelloMsg) marshal() []byte {
z = z[1+l:] z = z[1+l:]
} }
} }
if m.certStatus { if m.ocspStapling {
z[0] = byte(extensionStatusRequest >> 8) z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest) z[1] = byte(extensionStatusRequest)
z = z[4:] z = z[4:]
...@@ -406,7 +406,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { ...@@ -406,7 +406,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtoNeg = false m.nextProtoNeg = false
m.nextProtos = nil m.nextProtos = nil
m.certStatus = false m.ocspStapling = false
if len(data) == 0 { if len(data) == 0 {
// ServerHello is optionally followed by extension data // ServerHello is optionally followed by extension data
...@@ -450,7 +450,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { ...@@ -450,7 +450,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
if length > 0 { if length > 0 {
return false return false
} }
m.certStatus = true m.ocspStapling = true
} }
data = data[length:] data = data[length:]
} }
......
...@@ -32,7 +32,7 @@ type testMessage interface { ...@@ -32,7 +32,7 @@ type testMessage interface {
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
for i, iface := range tests { for i, iface := range tests {
ty := reflect.NewValue(iface).Type() ty := reflect.ValueOf(iface).Type()
n := 100 n := 100
if testing.Short() { if testing.Short() {
...@@ -121,11 +121,11 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -121,11 +121,11 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.ocspStapling = rand.Intn(10) > 5 m.ocspStapling = rand.Intn(10) > 5
m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
m.supportedCurves = make([]uint16, rand.Intn(5)+1) m.supportedCurves = make([]uint16, rand.Intn(5)+1)
for i, _ := range m.supportedCurves { for i := range m.supportedCurves {
m.supportedCurves[i] = uint16(rand.Intn(30000)) m.supportedCurves[i] = uint16(rand.Intn(30000))
} }
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
...@@ -146,7 +146,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -146,7 +146,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
} }
} }
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
...@@ -156,7 +156,7 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -156,7 +156,7 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
for i := 0; i < numCerts; i++ { for i := 0; i < numCerts; i++ {
m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
} }
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
...@@ -167,13 +167,13 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value ...@@ -167,13 +167,13 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value
for i := 0; i < numCAs; i++ { for i := 0; i < numCAs; i++ {
m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
} }
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateVerifyMsg{} m := &certificateVerifyMsg{}
m.signature = randomBytes(rand.Intn(15)+1, rand) m.signature = randomBytes(rand.Intn(15)+1, rand)
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
...@@ -184,23 +184,23 @@ func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -184,23 +184,23 @@ func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
} else { } else {
m.statusType = 42 m.statusType = 42
} }
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientKeyExchangeMsg{} m := &clientKeyExchangeMsg{}
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &finishedMsg{} m := &finishedMsg{}
m.verifyData = randomBytes(12, rand) m.verifyData = randomBytes(12, rand)
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &nextProtoMsg{} m := &nextProtoMsg{}
m.proto = randomString(rand.Intn(255), rand) m.proto = randomString(rand.Intn(255), rand)
return reflect.NewValue(m) return reflect.ValueOf(m)
} }
...@@ -103,6 +103,9 @@ FindCipherSuite: ...@@ -103,6 +103,9 @@ FindCipherSuite:
hello.nextProtoNeg = true hello.nextProtoNeg = true
hello.nextProtos = config.NextProtos hello.nextProtos = config.NextProtos
} }
if clientHello.ocspStapling && len(config.Certificates[0].OCSPStaple) > 0 {
hello.ocspStapling = true
}
finishedHash.Write(hello.marshal()) finishedHash.Write(hello.marshal())
c.writeRecord(recordTypeHandshake, hello.marshal()) c.writeRecord(recordTypeHandshake, hello.marshal())
...@@ -116,6 +119,14 @@ FindCipherSuite: ...@@ -116,6 +119,14 @@ FindCipherSuite:
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal())
if hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP
certStatus.response = config.Certificates[0].OCSPStaple
finishedHash.Write(certStatus.marshal())
c.writeRecord(recordTypeHandshake, certStatus.marshal())
}
keyAgreement := suite.ka() keyAgreement := suite.ka()
skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello) skx, err := keyAgreement.generateServerKeyExchange(config, clientHello, hello)
......
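On the server side, stapling only requires filling in the new OCSPStaple field of the first configured certificate; when a client sends status_request, the code above answers with a certificateStatus message carrying those bytes. A sketch with placeholder file names, assuming LoadX509KeyPair from the same package:

package tlsstaple

import (
	"crypto/tls"
	"io/ioutil"
	"os"
)

// loadCertWithStaple loads a server certificate and attaches a pre-fetched,
// DER-encoded OCSP response read from ocspFile.
func loadCertWithStaple(certFile, keyFile, ocspFile string) (tls.Certificate, os.Error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return cert, err
	}
	cert.OCSPStaple, err = ioutil.ReadFile(ocspFile)
	return cert, err
}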
...@@ -188,8 +188,10 @@ var testPrivateKey = &rsa.PrivateKey{ ...@@ -188,8 +188,10 @@ var testPrivateKey = &rsa.PrivateKey{
E: 65537, E: 65537,
}, },
D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"), D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"),
P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"), Primes: []*big.Int{
Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"), bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"),
bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"),
},
} }
// Script of interaction with gnutls implementation. // Script of interaction with gnutls implementation.
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package partially implements the TLS 1.1 protocol, as specified in RFC 4346. // Package tls partially implements the TLS 1.1 protocol, as specified in RFC
// 4346.
package tls package tls
import ( import (
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements Bruce Schneier's Twofish encryption algorithm. // Package twofish implements Bruce Schneier's Twofish encryption algorithm.
package twofish package twofish
// Twofish is defined in http://www.schneier.com/paper-twofish-paper.pdf [TWOFISH] // Twofish is defined in http://www.schneier.com/paper-twofish-paper.pdf [TWOFISH]
......
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package tls package x509
import ( import (
"crypto/x509"
"encoding/pem" "encoding/pem"
"strings" "strings"
) )
// A CASet is a set of certificates. // CertPool is a set of certificates.
type CASet struct { type CertPool struct {
bySubjectKeyId map[string][]*x509.Certificate bySubjectKeyId map[string][]int
byName map[string][]*x509.Certificate byName map[string][]int
certs []*Certificate
} }
// NewCASet returns a new, empty CASet. // NewCertPool returns a new, empty CertPool.
func NewCASet() *CASet { func NewCertPool() *CertPool {
return &CASet{ return &CertPool{
make(map[string][]*x509.Certificate), make(map[string][]int),
make(map[string][]*x509.Certificate), make(map[string][]int),
nil,
} }
} }
func nameToKey(name *x509.Name) string { func nameToKey(name *Name) string {
return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName return strings.Join(name.Country, ",") + "/" + strings.Join(name.Organization, ",") + "/" + strings.Join(name.OrganizationalUnit, ",") + "/" + name.CommonName
} }
// FindVerifiedParent attempts to find the certificate in s which has signed // findVerifiedParents attempts to find certificates in s which have signed the
// the given certificate. If no such certificate can be found or the signature // given certificate. If no such certificate can be found or the signature
// doesn't match, it returns nil. // doesn't match, it returns nil.
func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certificate) { func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int) {
var candidates []*x509.Certificate var candidates []int
if len(cert.AuthorityKeyId) > 0 { if len(cert.AuthorityKeyId) > 0 {
candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)]
...@@ -42,30 +43,45 @@ func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certifi ...@@ -42,30 +43,45 @@ func (s *CASet) FindVerifiedParent(cert *x509.Certificate) (parent *x509.Certifi
} }
for _, c := range candidates { for _, c := range candidates {
if cert.CheckSignatureFrom(c) == nil { if cert.CheckSignatureFrom(s.certs[c]) == nil {
return c parents = append(parents, c)
} }
} }
return nil return
} }
// AddCert adds a certificate to the set // AddCert adds a certificate to a pool.
func (s *CASet) AddCert(cert *x509.Certificate) { func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
// Check that the certificate isn't being added twice.
for _, c := range s.certs {
if c.Equal(cert) {
return
}
}
n := len(s.certs)
s.certs = append(s.certs, cert)
if len(cert.SubjectKeyId) > 0 { if len(cert.SubjectKeyId) > 0 {
keyId := string(cert.SubjectKeyId) keyId := string(cert.SubjectKeyId)
s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], cert) s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n)
} }
name := nameToKey(&cert.Subject) name := nameToKey(&cert.Subject)
s.byName[name] = append(s.byName[name], cert) s.byName[name] = append(s.byName[name], n)
} }
// SetFromPEM attempts to parse a series of PEM encoded root certificates. It // AppendCertsFromPEM attempts to parse a series of PEM encoded root
// appends any certificates found to s and returns true if any certificates // certificates. It appends any certificates found to s and returns true if any
// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will // certificates were successfully parsed.
// contains the system wide set of root CAs in a format suitable for this //
function. // On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { // of root CAs in a format suitable for this function.
func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
for len(pemCerts) > 0 { for len(pemCerts) > 0 {
var block *pem.Block var block *pem.Block
block, pemCerts = pem.Decode(pemCerts) block, pemCerts = pem.Decode(pemCerts)
...@@ -76,7 +92,7 @@ func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) { ...@@ -76,7 +92,7 @@ func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) {
continue continue
} }
cert, err := x509.ParseCertificate(block.Bytes) cert, err := ParseCertificate(block.Bytes)
if err != nil { if err != nil {
continue continue
} }
......
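A CertPool stores each certificate once and indexes it by subject key id and by name; AddCert silently ignores duplicates. A minimal usage sketch against this patch, collecting DER-encoded certificates into a pool:

package poolexample

import (
	"crypto/x509"
	"os"
)

// poolFromDER parses DER-encoded certificates and collects them in a pool;
// AddCert skips duplicates, so repeated entries are harmless.
func poolFromDER(ders [][]byte) (*x509.CertPool, os.Error) {
	pool := x509.NewCertPool()
	for _, der := range ders {
		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, err
		}
		pool.AddCert(cert)
	}
	return pool, nil
}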
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"os"
"strings"
"time"
)
type InvalidReason int
const (
// NotAuthorizedToSign results when a certificate is signed by another
// which isn't marked as a CA certificate.
NotAuthorizedToSign InvalidReason = iota
// Expired results when a certificate has expired, based on the time
// given in the VerifyOptions.
Expired
// CANotAuthorizedForThisName results when an intermediate or root
// certificate has a name constraint which doesn't include the name
// being checked.
CANotAuthorizedForThisName
)
// CertificateInvalidError results when an odd error occurs. Users of this
// library probably want to handle all these errors uniformly.
type CertificateInvalidError struct {
Cert *Certificate
Reason InvalidReason
}
func (e CertificateInvalidError) String() string {
switch e.Reason {
case NotAuthorizedToSign:
return "x509: certificate is not authorized to sign other other certificates"
case Expired:
return "x509: certificate has expired or is not yet valid"
case CANotAuthorizedForThisName:
return "x509: a root or intermediate certificate is not authorized to sign in this domain"
}
return "x509: unknown error"
}
// HostnameError results when the set of authorized names doesn't match the
// requested name.
type HostnameError struct {
Certificate *Certificate
Host string
}
func (h HostnameError) String() string {
var valid string
c := h.Certificate
if len(c.DNSNames) > 0 {
valid = strings.Join(c.DNSNames, ", ")
} else {
valid = c.Subject.CommonName
}
return "certificate is valid for " + valid + ", not " + h.Host
}
// UnknownAuthorityError results when the certificate issuer is unknown
type UnknownAuthorityError struct {
cert *Certificate
}
func (e UnknownAuthorityError) String() string {
return "x509: certificate signed by unknown authority"
}
// VerifyOptions contains parameters for Certificate.Verify. It's a structure
// because other PKIX verification APIs have ended up needing many options.
type VerifyOptions struct {
DNSName string
Intermediates *CertPool
Roots *CertPool
CurrentTime int64 // if 0, the current system time is used.
}
const (
leafCertificate = iota
intermediateCertificate
rootCertificate
)
// isValid performs validity checks on c.
func (c *Certificate) isValid(certType int, opts *VerifyOptions) os.Error {
if opts.CurrentTime < c.NotBefore.Seconds() ||
opts.CurrentTime > c.NotAfter.Seconds() {
return CertificateInvalidError{c, Expired}
}
if len(c.PermittedDNSDomains) > 0 {
for _, domain := range c.PermittedDNSDomains {
if opts.DNSName == domain ||
(strings.HasSuffix(opts.DNSName, domain) &&
len(opts.DNSName) >= 1+len(domain) &&
opts.DNSName[len(opts.DNSName)-len(domain)-1] == '.') {
continue
}
return CertificateInvalidError{c, CANotAuthorizedForThisName}
}
}
// KeyUsage status flags are ignored. From Engineering Security, Peter
// Gutmann: A European government CA marked its signing certificates as
// being valid for encryption only, but no-one noticed. Another
// European CA marked its signature keys as not being valid for
// signatures. A different CA marked its own trusted root certificate
// as being invalid for certificate signing. Another national CA
// distributed a certificate to be used to encrypt data for the
// country’s tax authority that was marked as only being usable for
// digital signatures but not for encryption. Yet another CA reversed
// the order of the bit flags in the keyUsage due to confusion over
// encoding endianness, essentially setting a random keyUsage in
// certificates that it issued. Another CA created a self-invalidating
// certificate by adding a certificate policy statement stipulating
// that the certificate had to be used strictly as specified in the
// keyUsage, and a keyUsage containing a flag indicating that the RSA
// encryption key could only be used for Diffie-Hellman key agreement.
if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) {
return CertificateInvalidError{c, NotAuthorizedToSign}
}
return nil
}
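The name-constraint test above accepts either an exact match or a dot-separated suffix match against each permitted domain. The same check, restated on its own (independent of the Certificate type):

package constraintsketch

import "strings"

// domainPermitted mirrors the PermittedDNSDomains test in isValid: the name
// must equal the domain or end in "." followed by the domain.
func domainPermitted(name, domain string) bool {
	if name == domain {
		return true
	}
	return strings.HasSuffix(name, domain) &&
		len(name) >= 1+len(domain) &&
		name[len(name)-len(domain)-1] == '.'
}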
// Verify attempts to verify c by building one or more chains from c to a
// certificate in opts.Roots, using certificates in opts.Intermediates if
// needed. If successful, it returns one or more chains where the first element of
// the chain is c and the last element is from opts.Roots.
//
// WARNING: this doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err os.Error) {
if opts.CurrentTime == 0 {
opts.CurrentTime = time.Seconds()
}
err = c.isValid(leafCertificate, &opts)
if err != nil {
return
}
if len(opts.DNSName) > 0 {
err = c.VerifyHostname(opts.DNSName)
if err != nil {
return
}
}
return c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts)
}
func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate {
n := make([]*Certificate, len(chain)+1)
copy(n, chain)
n[len(chain)] = cert
return n
}
func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err os.Error) {
for _, rootNum := range opts.Roots.findVerifiedParents(c) {
root := opts.Roots.certs[rootNum]
err = root.isValid(rootCertificate, opts)
if err != nil {
continue
}
chains = append(chains, appendToFreshChain(currentChain, root))
}
for _, intermediateNum := range opts.Intermediates.findVerifiedParents(c) {
intermediate := opts.Intermediates.certs[intermediateNum]
err = intermediate.isValid(intermediateCertificate, opts)
if err != nil {
continue
}
var childChains [][]*Certificate
childChains, ok := cache[intermediateNum]
if !ok {
childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts)
cache[intermediateNum] = childChains
}
chains = append(chains, childChains...)
}
if len(chains) > 0 {
err = nil
}
if len(chains) == 0 && err == nil {
err = UnknownAuthorityError{c}
}
return
}
func matchHostnames(pattern, host string) bool {
if len(pattern) == 0 || len(host) == 0 {
return false
}
patternParts := strings.Split(pattern, ".", -1)
hostParts := strings.Split(host, ".", -1)
if len(patternParts) != len(hostParts) {
return false
}
for i, patternPart := range patternParts {
if patternPart == "*" {
continue
}
if patternPart != hostParts[i] {
return false
}
}
return true
}
// VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an os.Error describing the mismatch.
func (c *Certificate) VerifyHostname(h string) os.Error {
if len(c.DNSNames) > 0 {
for _, match := range c.DNSNames {
if matchHostnames(match, h) {
return nil
}
}
// If Subject Alt Name is given, we ignore the common name.
} else if matchHostnames(c.Subject.CommonName, h) {
return nil
}
return HostnameError{c, h}
}
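VerifyHostname and the HostnameError type can also be used directly on a parsed certificate, outside of a full Verify call. A short sketch (the DER bytes and host name are assumed inputs):

package hostcheck

import (
	"crypto/x509"
	"os"
)

// checkHost parses a DER-encoded certificate and checks whether it is valid
// for host; a mismatch is reported as a HostnameError.
func checkHost(der []byte, host string) os.Error {
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return err
	}
	return cert.VerifyHostname(host)
}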
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package parses X.509-encoded keys and certificates. // Package x509 parses X.509-encoded keys and certificates.
package x509 package x509
import ( import (
"asn1" "asn1"
"big" "big"
"bytes"
"container/vector" "container/vector"
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
...@@ -15,7 +16,6 @@ import ( ...@@ -15,7 +16,6 @@ import (
"hash" "hash"
"io" "io"
"os" "os"
"strings"
"time" "time"
) )
...@@ -27,6 +27,20 @@ type pkcs1PrivateKey struct { ...@@ -27,6 +27,20 @@ type pkcs1PrivateKey struct {
D asn1.RawValue D asn1.RawValue
P asn1.RawValue P asn1.RawValue
Q asn1.RawValue Q asn1.RawValue
// We ignore these values, if present, because rsa will calculate them.
Dp asn1.RawValue "optional"
Dq asn1.RawValue "optional"
Qinv asn1.RawValue "optional"
AdditionalPrimes []pkcs1AdditionalRSAPrime "optional"
}
type pkcs1AdditionalRSAPrime struct {
Prime asn1.RawValue
// We ignore these values because rsa will calculate them.
Exp asn1.RawValue
Coeff asn1.RawValue
} }
// rawValueIsInteger returns true iff the given ASN.1 RawValue is an INTEGER type. // rawValueIsInteger returns true iff the given ASN.1 RawValue is an INTEGER type.
...@@ -46,6 +60,10 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { ...@@ -46,6 +60,10 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) {
return return
} }
if priv.Version > 1 {
return nil, os.ErrorString("x509: unsupported private key version")
}
if !rawValueIsInteger(&priv.N) || if !rawValueIsInteger(&priv.N) ||
!rawValueIsInteger(&priv.D) || !rawValueIsInteger(&priv.D) ||
!rawValueIsInteger(&priv.P) || !rawValueIsInteger(&priv.P) ||
...@@ -61,26 +79,66 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) { ...@@ -61,26 +79,66 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) {
} }
key.D = new(big.Int).SetBytes(priv.D.Bytes) key.D = new(big.Int).SetBytes(priv.D.Bytes)
key.P = new(big.Int).SetBytes(priv.P.Bytes) key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes))
key.Q = new(big.Int).SetBytes(priv.Q.Bytes) key.Primes[0] = new(big.Int).SetBytes(priv.P.Bytes)
key.Primes[1] = new(big.Int).SetBytes(priv.Q.Bytes)
for i, a := range priv.AdditionalPrimes {
if !rawValueIsInteger(&a.Prime) {
return nil, asn1.StructuralError{"tags don't match"}
}
key.Primes[i+2] = new(big.Int).SetBytes(a.Prime.Bytes)
// We ignore the other two values because rsa will calculate
// them as needed.
}
err = key.Validate() err = key.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
key.Precompute()
return return
} }
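ParsePKCS1PrivateKey now also accepts version-1 (multi-prime) structures, fills in Primes, and precomputes the CRT values itself. A round-trip sketch using MarshalPKCS1PrivateKey from further below, with a freshly generated two-prime key (the small key size is only to keep the example fast):

package main

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 512)
	if err != nil {
		panic(err)
	}
	der := x509.MarshalPKCS1PrivateKey(priv)
	priv2, err := x509.ParsePKCS1PrivateKey(der)
	if err != nil {
		panic(err)
	}
	// The modulus, exponents and both primes survive the round trip.
	fmt.Println(priv.D.Cmp(priv2.D) == 0, len(priv2.Primes)) // true 2
}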
// rawValueForBig returns an asn1.RawValue which represents the given integer.
func rawValueForBig(n *big.Int) asn1.RawValue {
b := n.Bytes()
if n.Sign() >= 0 && len(b) > 0 && b[0]&0x80 != 0 {
// This positive number would be interpreted as a negative
// number in ASN.1 because the MSB is set.
padded := make([]byte, len(b)+1)
copy(padded[1:], b)
b = padded
}
return asn1.RawValue{Tag: 2, Bytes: b}
}
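The extra leading zero byte keeps a positive big.Int whose top bit is set from being decoded as a negative ASN.1 INTEGER. A worked illustration of the padding rule (the value 128 is chosen because its single byte, 0x80, has the most significant bit set):

package main

import (
	"big"
	"fmt"
)

func main() {
	n := big.NewInt(128)
	b := n.Bytes() // [0x80]: would read back as a negative INTEGER
	if n.Sign() >= 0 && len(b) > 0 && b[0]&0x80 != 0 {
		padded := make([]byte, len(b)+1)
		copy(padded[1:], b)
		b = padded
	}
	fmt.Printf("% x\n", b) // 00 80
}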
// MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form. // MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form.
func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
key.Precompute()
version := 0
if len(key.Primes) > 2 {
version = 1
}
priv := pkcs1PrivateKey{ priv := pkcs1PrivateKey{
Version: 1, Version: version,
N: asn1.RawValue{Tag: 2, Bytes: key.PublicKey.N.Bytes()}, N: rawValueForBig(key.N),
E: key.PublicKey.E, E: key.PublicKey.E,
D: asn1.RawValue{Tag: 2, Bytes: key.D.Bytes()}, D: rawValueForBig(key.D),
P: asn1.RawValue{Tag: 2, Bytes: key.P.Bytes()}, P: rawValueForBig(key.Primes[0]),
Q: asn1.RawValue{Tag: 2, Bytes: key.Q.Bytes()}, Q: rawValueForBig(key.Primes[1]),
Dp: rawValueForBig(key.Precomputed.Dp),
Dq: rawValueForBig(key.Precomputed.Dq),
Qinv: rawValueForBig(key.Precomputed.Qinv),
}
priv.AdditionalPrimes = make([]pkcs1AdditionalRSAPrime, len(key.Precomputed.CRTValues))
for i, values := range key.Precomputed.CRTValues {
priv.AdditionalPrimes[i].Prime = rawValueForBig(key.Primes[2+i])
priv.AdditionalPrimes[i].Exp = rawValueForBig(values.Exp)
priv.AdditionalPrimes[i].Coeff = rawValueForBig(values.Coeff)
} }
b, _ := asn1.Marshal(priv) b, _ := asn1.Marshal(priv)
...@@ -90,6 +148,7 @@ func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { ...@@ -90,6 +148,7 @@ func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
// These structures reflect the ASN.1 structure of X.509 certificates: // These structures reflect the ASN.1 structure of X.509 certificates:
type certificate struct { type certificate struct {
Raw asn1.RawContent
TBSCertificate tbsCertificate TBSCertificate tbsCertificate
SignatureAlgorithm algorithmIdentifier SignatureAlgorithm algorithmIdentifier
SignatureValue asn1.BitString SignatureValue asn1.BitString
...@@ -127,6 +186,7 @@ type validity struct { ...@@ -127,6 +186,7 @@ type validity struct {
} }
type publicKeyInfo struct { type publicKeyInfo struct {
Raw asn1.RawContent
Algorithm algorithmIdentifier Algorithm algorithmIdentifier
PublicKey asn1.BitString PublicKey asn1.BitString
} }
...@@ -343,7 +403,10 @@ const ( ...@@ -343,7 +403,10 @@ const (
// A Certificate represents an X.509 certificate. // A Certificate represents an X.509 certificate.
type Certificate struct { type Certificate struct {
Raw []byte // Raw ASN.1 DER contents. Raw []byte // Complete ASN.1 DER content (certificate, signature algorithm and signature).
RawTBSCertificate []byte // Certificate part of raw ASN.1 DER content.
RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo.
Signature []byte Signature []byte
SignatureAlgorithm SignatureAlgorithm SignatureAlgorithm SignatureAlgorithm
...@@ -395,6 +458,10 @@ func (ConstraintViolationError) String() string { ...@@ -395,6 +458,10 @@ func (ConstraintViolationError) String() string {
return "invalid signature: parent certificate cannot sign this kind of certificate" return "invalid signature: parent certificate cannot sign this kind of certificate"
} }
func (c *Certificate) Equal(other *Certificate) bool {
return bytes.Equal(c.Raw, other.Raw)
}
// CheckSignatureFrom verifies that the signature on c is a valid signature // CheckSignatureFrom verifies that the signature on c is a valid signature
// from parent. // from parent.
func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) {
...@@ -434,69 +501,12 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) { ...@@ -434,69 +501,12 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err os.Error) {
return UnsupportedAlgorithmError{} return UnsupportedAlgorithmError{}
} }
h.Write(c.Raw) h.Write(c.RawTBSCertificate)
digest := h.Sum() digest := h.Sum()
return rsa.VerifyPKCS1v15(pub, hashType, digest, c.Signature) return rsa.VerifyPKCS1v15(pub, hashType, digest, c.Signature)
} }
func matchHostnames(pattern, host string) bool {
if len(pattern) == 0 || len(host) == 0 {
return false
}
patternParts := strings.Split(pattern, ".", -1)
hostParts := strings.Split(host, ".", -1)
if len(patternParts) != len(hostParts) {
return false
}
for i, patternPart := range patternParts {
if patternPart == "*" {
continue
}
if patternPart != hostParts[i] {
return false
}
}
return true
}
type HostnameError struct {
Certificate *Certificate
Host string
}
func (h *HostnameError) String() string {
var valid string
c := h.Certificate
if len(c.DNSNames) > 0 {
valid = strings.Join(c.DNSNames, ", ")
} else {
valid = c.Subject.CommonName
}
return "certificate is valid for " + valid + ", not " + h.Host
}
// VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an os.Error describing the mismatch.
func (c *Certificate) VerifyHostname(h string) os.Error {
if len(c.DNSNames) > 0 {
for _, match := range c.DNSNames {
if matchHostnames(match, h) {
return nil
}
}
// If Subject Alt Name is given, we ignore the common name.
} else if matchHostnames(c.Subject.CommonName, h) {
return nil
}
return &HostnameError{c, h}
}
type UnhandledCriticalExtension struct{} type UnhandledCriticalExtension struct{}
func (h UnhandledCriticalExtension) String() string { func (h UnhandledCriticalExtension) String() string {
...@@ -558,7 +568,9 @@ func parsePublicKey(algo PublicKeyAlgorithm, asn1Data []byte) (interface{}, os.E ...@@ -558,7 +568,9 @@ func parsePublicKey(algo PublicKeyAlgorithm, asn1Data []byte) (interface{}, os.E
func parseCertificate(in *certificate) (*Certificate, os.Error) { func parseCertificate(in *certificate) (*Certificate, os.Error) {
out := new(Certificate) out := new(Certificate)
out.Raw = in.TBSCertificate.Raw out.Raw = in.Raw
out.RawTBSCertificate = in.TBSCertificate.Raw
out.RawSubjectPublicKeyInfo = in.TBSCertificate.PublicKey.Raw
out.Signature = in.SignatureValue.RightAlign() out.Signature = in.SignatureValue.RightAlign()
out.SignatureAlgorithm = out.SignatureAlgorithm =
...@@ -975,7 +987,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P ...@@ -975,7 +987,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
Issuer: parent.Subject.toRDNSequence(), Issuer: parent.Subject.toRDNSequence(),
Validity: validity{template.NotBefore, template.NotAfter}, Validity: validity{template.NotBefore, template.NotAfter},
Subject: template.Subject.toRDNSequence(), Subject: template.Subject.toRDNSequence(),
PublicKey: publicKeyInfo{algorithmIdentifier{oidRSA}, encodedPublicKey}, PublicKey: publicKeyInfo{nil, algorithmIdentifier{oidRSA}, encodedPublicKey},
Extensions: extensions, Extensions: extensions,
} }
...@@ -996,6 +1008,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P ...@@ -996,6 +1008,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
} }
cert, err = asn1.Marshal(certificate{ cert, err = asn1.Marshal(certificate{
nil,
c, c,
algorithmIdentifier{oidSHA1WithRSA}, algorithmIdentifier{oidSHA1WithRSA},
asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
......
...@@ -20,12 +20,13 @@ func TestParsePKCS1PrivateKey(t *testing.T) { ...@@ -20,12 +20,13 @@ func TestParsePKCS1PrivateKey(t *testing.T) {
priv, err := ParsePKCS1PrivateKey(block.Bytes) priv, err := ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
t.Errorf("Failed to parse private key: %s", err) t.Errorf("Failed to parse private key: %s", err)
return
} }
if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 || if priv.PublicKey.N.Cmp(rsaPrivateKey.PublicKey.N) != 0 ||
priv.PublicKey.E != rsaPrivateKey.PublicKey.E || priv.PublicKey.E != rsaPrivateKey.PublicKey.E ||
priv.D.Cmp(rsaPrivateKey.D) != 0 || priv.D.Cmp(rsaPrivateKey.D) != 0 ||
priv.P.Cmp(rsaPrivateKey.P) != 0 || priv.Primes[0].Cmp(rsaPrivateKey.Primes[0]) != 0 ||
priv.Q.Cmp(rsaPrivateKey.Q) != 0 { priv.Primes[1].Cmp(rsaPrivateKey.Primes[1]) != 0 {
t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey) t.Errorf("got:%+v want:%+v", priv, rsaPrivateKey)
} }
} }
...@@ -47,14 +48,54 @@ func bigFromString(s string) *big.Int { ...@@ -47,14 +48,54 @@ func bigFromString(s string) *big.Int {
return ret return ret
} }
func fromBase10(base10 string) *big.Int {
i := new(big.Int)
i.SetString(base10, 10)
return i
}
var rsaPrivateKey = &rsa.PrivateKey{ var rsaPrivateKey = &rsa.PrivateKey{
PublicKey: rsa.PublicKey{ PublicKey: rsa.PublicKey{
N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"), N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"),
E: 65537, E: 65537,
}, },
D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"), D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"),
P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"), Primes: []*big.Int{
Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"), bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"),
bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"),
},
}
func TestMarshalRSAPrivateKey(t *testing.T) {
priv := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: fromBase10("16346378922382193400538269749936049106320265317511766357599732575277382844051791096569333808598921852351577762718529818072849191122419410612033592401403764925096136759934497687765453905884149505175426053037420486697072448609022753683683718057795566811401938833367954642951433473337066311978821180526439641496973296037000052546108507805269279414789035461158073156772151892452251106173507240488993608650881929629163465099476849643165682709047462010581308719577053905787496296934240246311806555924593059995202856826239801816771116902778517096212527979497399966526283516447337775509777558018145573127308919204297111496233"),
E: 3,
},
D: fromBase10("10897585948254795600358846499957366070880176878341177571733155050184921896034527397712889205732614568234385175145686545381899460748279607074689061600935843283397424506622998458510302603922766336783617368686090042765718290914099334449154829375179958369993407724946186243249568928237086215759259909861748642124071874879861299389874230489928271621259294894142840428407196932444474088857746123104978617098858619445675532587787023228852383149557470077802718705420275739737958953794088728369933811184572620857678792001136676902250566845618813972833750098806496641114644760255910789397593428910198080271317419213080834885003"),
Primes: []*big.Int{
fromBase10("1025363189502892836833747188838978207017355117492483312747347695538428729137306368764177201532277413433182799108299960196606011786562992097313508180436744488171474690412562218914213688661311117337381958560443"),
fromBase10("3467903426626310123395340254094941045497208049900750380025518552334536945536837294961497712862519984786362199788654739924501424784631315081391467293694361474867825728031147665777546570788493758372218019373"),
fromBase10("4597024781409332673052708605078359346966325141767460991205742124888960305710298765592730135879076084498363772408626791576005136245060321874472727132746643162385746062759369754202494417496879741537284589047"),
},
}
derBytes := MarshalPKCS1PrivateKey(priv)
priv2, err := ParsePKCS1PrivateKey(derBytes)
if err != nil {
t.Errorf("error parsing serialized key: %s", err)
return
}
if priv.PublicKey.N.Cmp(priv2.PublicKey.N) != 0 ||
priv.PublicKey.E != priv2.PublicKey.E ||
priv.D.Cmp(priv2.D) != 0 ||
len(priv2.Primes) != 3 ||
priv.Primes[0].Cmp(priv2.Primes[0]) != 0 ||
priv.Primes[1].Cmp(priv2.Primes[1]) != 0 ||
priv.Primes[2].Cmp(priv2.Primes[2]) != 0 {
t.Errorf("got:%+v want:%+v", priv, priv2)
}
} }
type matchHostnamesTest struct { type matchHostnamesTest struct {
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements XTEA encryption, as defined in Needham and // Package xtea implements XTEA encryption, as defined in Needham and Wheeler's
// Wheeler's 1997 technical report, "Tea extensions." // 1997 technical report, "Tea extensions."
package xtea package xtea
// For details, see http://www.cix.co.uk/~klockstone/xtea.pdf // For details, see http://www.cix.co.uk/~klockstone/xtea.pdf
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package provides access to DWARF debugging information // Package dwarf provides access to DWARF debugging information loaded from
// loaded from executable files, as defined in the DWARF 2.0 Standard // executable files, as defined in the DWARF 2.0 Standard at
// at http://dwarfstd.org/doc/dwarf-2.0.0.pdf // http://dwarfstd.org/doc/dwarf-2.0.0.pdf
package dwarf package dwarf
import ( import (
......
...@@ -330,29 +330,35 @@ func (i SectionIndex) GoString() string { return stringName(uint32(i), shnString ...@@ -330,29 +330,35 @@ func (i SectionIndex) GoString() string { return stringName(uint32(i), shnString
type SectionType uint32 type SectionType uint32
const ( const (
SHT_NULL SectionType = 0 /* inactive */ SHT_NULL SectionType = 0 /* inactive */
SHT_PROGBITS SectionType = 1 /* program defined information */ SHT_PROGBITS SectionType = 1 /* program defined information */
SHT_SYMTAB SectionType = 2 /* symbol table section */ SHT_SYMTAB SectionType = 2 /* symbol table section */
SHT_STRTAB SectionType = 3 /* string table section */ SHT_STRTAB SectionType = 3 /* string table section */
SHT_RELA SectionType = 4 /* relocation section with addends */ SHT_RELA SectionType = 4 /* relocation section with addends */
SHT_HASH SectionType = 5 /* symbol hash table section */ SHT_HASH SectionType = 5 /* symbol hash table section */
SHT_DYNAMIC SectionType = 6 /* dynamic section */ SHT_DYNAMIC SectionType = 6 /* dynamic section */
SHT_NOTE SectionType = 7 /* note section */ SHT_NOTE SectionType = 7 /* note section */
SHT_NOBITS SectionType = 8 /* no space section */ SHT_NOBITS SectionType = 8 /* no space section */
SHT_REL SectionType = 9 /* relocation section - no addends */ SHT_REL SectionType = 9 /* relocation section - no addends */
SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */ SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */
SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */ SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */
SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */ SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */
SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */ SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */
SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */ SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */
SHT_GROUP SectionType = 17 /* Section group. */ SHT_GROUP SectionType = 17 /* Section group. */
SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */ SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */
SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */ SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */
SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */ SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */
SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */ SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */
SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */ SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */
SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */ SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */
SHT_HIUSER SectionType = 0xffffffff /* specific indexes */ SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */
SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */
SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */
SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */
SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */
SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */
SHT_HIUSER SectionType = 0xffffffff /* specific indexes */
) )
var shtStrings = []intName{ var shtStrings = []intName{
...@@ -374,7 +380,12 @@ var shtStrings = []intName{ ...@@ -374,7 +380,12 @@ var shtStrings = []intName{
{17, "SHT_GROUP"}, {17, "SHT_GROUP"},
{18, "SHT_SYMTAB_SHNDX"}, {18, "SHT_SYMTAB_SHNDX"},
{0x60000000, "SHT_LOOS"}, {0x60000000, "SHT_LOOS"},
{0x6fffffff, "SHT_HIOS"}, {0x6ffffff5, "SHT_GNU_ATTRIBUTES"},
{0x6ffffff6, "SHT_GNU_HASH"},
{0x6ffffff7, "SHT_GNU_LIBLIST"},
{0x6ffffffd, "SHT_GNU_VERDEF"},
{0x6ffffffe, "SHT_GNU_VERNEED"},
{0x6fffffff, "SHT_GNU_VERSYM"},
{0x70000000, "SHT_LOPROC"}, {0x70000000, "SHT_LOPROC"},
{0x7fffffff, "SHT_HIPROC"}, {0x7fffffff, "SHT_HIPROC"},
{0x80000000, "SHT_LOUSER"}, {0x80000000, "SHT_LOUSER"},
...@@ -518,6 +529,9 @@ const ( ...@@ -518,6 +529,9 @@ const (
DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */ DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */
DT_LOOS DynTag = 0x6000000d /* First OS-specific */ DT_LOOS DynTag = 0x6000000d /* First OS-specific */
DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */ DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */
DT_VERSYM DynTag = 0x6ffffff0
DT_VERNEED DynTag = 0x6ffffffe
DT_VERNEEDNUM DynTag = 0x6fffffff
DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */ DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */
DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */ DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */
) )
...@@ -559,6 +573,9 @@ var dtStrings = []intName{ ...@@ -559,6 +573,9 @@ var dtStrings = []intName{
{33, "DT_PREINIT_ARRAYSZ"}, {33, "DT_PREINIT_ARRAYSZ"},
{0x6000000d, "DT_LOOS"}, {0x6000000d, "DT_LOOS"},
{0x6ffff000, "DT_HIOS"}, {0x6ffff000, "DT_HIOS"},
{0x6ffffff0, "DT_VERSYM"},
{0x6ffffffe, "DT_VERNEED"},
{0x6fffffff, "DT_VERNEEDNUM"},
{0x70000000, "DT_LOPROC"}, {0x70000000, "DT_LOPROC"},
{0x7fffffff, "DT_HIPROC"}, {0x7fffffff, "DT_HIPROC"},
} }
......
...@@ -35,9 +35,11 @@ type FileHeader struct { ...@@ -35,9 +35,11 @@ type FileHeader struct {
// A File represents an open ELF file. // A File represents an open ELF file.
type File struct { type File struct {
FileHeader FileHeader
Sections []*Section Sections []*Section
Progs []*Prog Progs []*Prog
closer io.Closer closer io.Closer
gnuNeed []verneed
gnuVersym []byte
} }
// A SectionHeader represents a single ELF section header. // A SectionHeader represents a single ELF section header.
...@@ -329,8 +331,8 @@ func NewFile(r io.ReaderAt) (*File, os.Error) { ...@@ -329,8 +331,8 @@ func NewFile(r io.ReaderAt) (*File, os.Error) {
} }
// getSymbols returns a slice of Symbols from parsing the symbol table // getSymbols returns a slice of Symbols from parsing the symbol table
// with the given type. // with the given type, along with the associated string table.
func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, os.Error) {
switch f.Class { switch f.Class {
case ELFCLASS64: case ELFCLASS64:
return f.getSymbols64(typ) return f.getSymbols64(typ)
...@@ -339,27 +341,27 @@ func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) { ...@@ -339,27 +341,27 @@ func (f *File) getSymbols(typ SectionType) ([]Symbol, os.Error) {
return f.getSymbols32(typ) return f.getSymbols32(typ)
} }
return nil, os.ErrorString("not implemented") return nil, nil, os.ErrorString("not implemented")
} }
func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, os.Error) {
symtabSection := f.SectionByType(typ) symtabSection := f.SectionByType(typ)
if symtabSection == nil { if symtabSection == nil {
return nil, os.ErrorString("no symbol section") return nil, nil, os.ErrorString("no symbol section")
} }
data, err := symtabSection.Data() data, err := symtabSection.Data()
if err != nil { if err != nil {
return nil, os.ErrorString("cannot load symbol section") return nil, nil, os.ErrorString("cannot load symbol section")
} }
symtab := bytes.NewBuffer(data) symtab := bytes.NewBuffer(data)
if symtab.Len()%Sym32Size != 0 { if symtab.Len()%Sym32Size != 0 {
return nil, os.ErrorString("length of symbol section is not a multiple of SymSize") return nil, nil, os.ErrorString("length of symbol section is not a multiple of SymSize")
} }
strdata, err := f.stringTable(symtabSection.Link) strdata, err := f.stringTable(symtabSection.Link)
if err != nil { if err != nil {
return nil, os.ErrorString("cannot load string table section") return nil, nil, os.ErrorString("cannot load string table section")
} }
// The first entry is all zeros. // The first entry is all zeros.
...@@ -382,27 +384,27 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) { ...@@ -382,27 +384,27 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, os.Error) {
i++ i++
} }
return symbols, nil return symbols, strdata, nil
} }
func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, os.Error) {
symtabSection := f.SectionByType(typ) symtabSection := f.SectionByType(typ)
if symtabSection == nil { if symtabSection == nil {
return nil, os.ErrorString("no symbol section") return nil, nil, os.ErrorString("no symbol section")
} }
data, err := symtabSection.Data() data, err := symtabSection.Data()
if err != nil { if err != nil {
return nil, os.ErrorString("cannot load symbol section") return nil, nil, os.ErrorString("cannot load symbol section")
} }
symtab := bytes.NewBuffer(data) symtab := bytes.NewBuffer(data)
if symtab.Len()%Sym64Size != 0 { if symtab.Len()%Sym64Size != 0 {
return nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size") return nil, nil, os.ErrorString("length of symbol section is not a multiple of Sym64Size")
} }
strdata, err := f.stringTable(symtabSection.Link) strdata, err := f.stringTable(symtabSection.Link)
if err != nil { if err != nil {
return nil, os.ErrorString("cannot load string table section") return nil, nil, os.ErrorString("cannot load string table section")
} }
// The first entry is all zeros. // The first entry is all zeros.
...@@ -425,7 +427,7 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) { ...@@ -425,7 +427,7 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, os.Error) {
i++ i++
} }
return symbols, nil return symbols, strdata, nil
} }
// getString extracts a string from an ELF string table. // getString extracts a string from an ELF string table.
...@@ -468,7 +470,7 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) os.Error { ...@@ -468,7 +470,7 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) os.Error {
return os.ErrorString("length of relocation section is not a multiple of Sym64Size") return os.ErrorString("length of relocation section is not a multiple of Sym64Size")
} }
symbols, err := f.getSymbols(SHT_SYMTAB) symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil { if err != nil {
return err return err
} }
...@@ -544,24 +546,123 @@ func (f *File) DWARF() (*dwarf.Data, os.Error) { ...@@ -544,24 +546,123 @@ func (f *File) DWARF() (*dwarf.Data, os.Error) {
return dwarf.New(abbrev, nil, nil, info, nil, nil, nil, str) return dwarf.New(abbrev, nil, nil, info, nil, nil, nil, str)
} }
type ImportedSymbol struct {
Name string
Version string
Library string
}
// ImportedSymbols returns the names of all symbols // ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be // referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time. // satisfied by other libraries at dynamic load time.
// It does not return weak symbols. // It does not return weak symbols.
func (f *File) ImportedSymbols() ([]string, os.Error) { func (f *File) ImportedSymbols() ([]ImportedSymbol, os.Error) {
sym, err := f.getSymbols(SHT_DYNSYM) sym, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var all []string f.gnuVersionInit(str)
for _, s := range sym { var all []ImportedSymbol
for i, s := range sym {
if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF { if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF {
all = append(all, s.Name) all = append(all, ImportedSymbol{Name: s.Name})
f.gnuVersion(i, &all[len(all)-1])
} }
} }
return all, nil return all, nil
} }
type verneed struct {
File string
Name string
}
// gnuVersionInit parses the GNU version tables
// for use by calls to gnuVersion.
func (f *File) gnuVersionInit(str []byte) {
// Accumulate verneed information.
vn := f.SectionByType(SHT_GNU_VERNEED)
if vn == nil {
return
}
d, _ := vn.Data()
var need []verneed
i := 0
for {
if i+16 > len(d) {
break
}
vers := f.ByteOrder.Uint16(d[i : i+2])
if vers != 1 {
break
}
cnt := f.ByteOrder.Uint16(d[i+2 : i+4])
fileoff := f.ByteOrder.Uint32(d[i+4 : i+8])
aux := f.ByteOrder.Uint32(d[i+8 : i+12])
next := f.ByteOrder.Uint32(d[i+12 : i+16])
file, _ := getString(str, int(fileoff))
var name string
j := i + int(aux)
for c := 0; c < int(cnt); c++ {
if j+16 > len(d) {
break
}
// hash := f.ByteOrder.Uint32(d[j:j+4])
// flags := f.ByteOrder.Uint16(d[j+4:j+6])
other := f.ByteOrder.Uint16(d[j+6 : j+8])
nameoff := f.ByteOrder.Uint32(d[j+8 : j+12])
next := f.ByteOrder.Uint32(d[j+12 : j+16])
name, _ = getString(str, int(nameoff))
ndx := int(other)
if ndx >= len(need) {
a := make([]verneed, 2*(ndx+1))
copy(a, need)
need = a
}
need[ndx] = verneed{file, name}
if next == 0 {
break
}
j += int(next)
}
if next == 0 {
break
}
i += int(next)
}
// Versym parallels the symbol table, indexing into verneed.
vs := f.SectionByType(SHT_GNU_VERSYM)
if vs == nil {
return
}
d, _ = vs.Data()
f.gnuNeed = need
f.gnuVersym = d
}
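gnuVersionInit decodes the SHT_GNU_VERNEED section by hand, reading fixed 16-byte records at the offsets computed above. For reference, those offsets correspond to the ELF Verneed/Vernaux records; the struct mirrors below are purely illustrative (hypothetical names, not part of the package):

package elf

// elfVerneed mirrors the outer record walked by the i/next loop above.
type elfVerneed struct {
	Version uint16 // +0  vn_version; the parser stops unless this is 1
	Cnt     uint16 // +2  vn_cnt, number of Vernaux entries that follow
	File    uint32 // +4  vn_file, string-table offset of the library name
	Aux     uint32 // +8  vn_aux, offset of the first Vernaux from this record
	Next    uint32 // +12 vn_next, offset of the next Verneed, 0 if last
}

// elfVernaux mirrors the inner record walked by the j/next loop above.
type elfVernaux struct {
	Hash  uint32 // +0  vna_hash (skipped above)
	Flags uint16 // +4  vna_flags (skipped above)
	Other uint16 // +6  vna_other, the version index stored in .gnu.version
	Name  uint32 // +8  vna_name, string-table offset of the version string
	Next  uint32 // +12 vna_next, offset of the next Vernaux, 0 if last
}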
// gnuVersion adds Library and Version information to sym,
// which came from offset i of the symbol table.
func (f *File) gnuVersion(i int, sym *ImportedSymbol) {
// Each entry is two bytes; skip undef entry at beginning.
i = (i + 1) * 2
if i >= len(f.gnuVersym) {
return
}
j := int(f.ByteOrder.Uint16(f.gnuVersym[i:]))
if j < 2 || j >= len(f.gnuNeed) {
return
}
n := &f.gnuNeed[j]
sym.Library = n.File
sym.Version = n.Name
}
// ImportedLibraries returns the names of all libraries // ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be // referred to by the binary f that are expected to be
// linked with the binary at dynamic link time. // linked with the binary at dynamic link time.
......
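With ImportedSymbols now returning []ImportedSymbol, callers can see which library and symbol version each undefined dynamic symbol is expected to resolve against. A minimal usage sketch against the pre-Go 1 API in this tree (the input path is illustrative):

package main

import (
	"debug/elf"
	"fmt"
)

func main() {
	f, err := elf.Open("/bin/ls") // illustrative input
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()

	syms, err := f.ImportedSymbols()
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, s := range syms {
		// Library and Version are filled in from SHT_GNU_VERNEED /
		// SHT_GNU_VERSYM via gnuVersionInit and gnuVersion above.
		fmt.Printf("%s %s %s\n", s.Name, s.Version, s.Library)
	}
}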
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// A library for EBNF grammars. The input is text ([]byte) satisfying // Package ebnf is a library for EBNF grammars. The input is text ([]byte)
// the following grammar (represented itself in EBNF): // satisfying the following grammar (represented itself in EBNF):
// //
// Production = name "=" Expression "." . // Production = name "=" Expression "." .
// Expression = Alternative { "|" Alternative } . // Expression = Alternative { "|" Alternative } .
......
...@@ -126,10 +126,10 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } ...@@ -126,10 +126,10 @@ func (bigEndian) GoString() string { return "binary.BigEndian" }
// and written to successive fields of the data. // and written to successive fields of the data.
func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { func Read(r io.Reader, order ByteOrder, data interface{}) os.Error {
var v reflect.Value var v reflect.Value
switch d := reflect.NewValue(data).(type) { switch d := reflect.ValueOf(data); d.Kind() {
case *reflect.PtrValue: case reflect.Ptr:
v = d.Elem() v = d.Elem()
case *reflect.SliceValue: case reflect.Slice:
v = d v = d
default: default:
return os.NewError("binary.Read: invalid type " + d.Type().String()) return os.NewError("binary.Read: invalid type " + d.Type().String())
...@@ -155,7 +155,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) os.Error { ...@@ -155,7 +155,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) os.Error {
// Bytes written to w are encoded using the specified byte order // Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data. // and read from successive fields of the data.
func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { func Write(w io.Writer, order ByteOrder, data interface{}) os.Error {
v := reflect.Indirect(reflect.NewValue(data)) v := reflect.Indirect(reflect.ValueOf(data))
size := TotalSize(v) size := TotalSize(v)
if size < 0 { if size < 0 {
return os.NewError("binary.Write: invalid type " + v.Type().String()) return os.NewError("binary.Write: invalid type " + v.Type().String())
...@@ -168,26 +168,26 @@ func Write(w io.Writer, order ByteOrder, data interface{}) os.Error { ...@@ -168,26 +168,26 @@ func Write(w io.Writer, order ByteOrder, data interface{}) os.Error {
} }
func TotalSize(v reflect.Value) int { func TotalSize(v reflect.Value) int {
if sv, ok := v.(*reflect.SliceValue); ok { if v.Kind() == reflect.Slice {
elem := sizeof(v.Type().(*reflect.SliceType).Elem()) elem := sizeof(v.Type().Elem())
if elem < 0 { if elem < 0 {
return -1 return -1
} }
return sv.Len() * elem return v.Len() * elem
} }
return sizeof(v.Type()) return sizeof(v.Type())
} }
func sizeof(v reflect.Type) int { func sizeof(t reflect.Type) int {
switch t := v.(type) { switch t.Kind() {
case *reflect.ArrayType: case reflect.Array:
n := sizeof(t.Elem()) n := sizeof(t.Elem())
if n < 0 { if n < 0 {
return -1 return -1
} }
return t.Len() * n return t.Len() * n
case *reflect.StructType: case reflect.Struct:
sum := 0 sum := 0
for i, n := 0, t.NumField(); i < n; i++ { for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type) s := sizeof(t.Field(i).Type)
...@@ -198,12 +198,10 @@ func sizeof(v reflect.Type) int { ...@@ -198,12 +198,10 @@ func sizeof(v reflect.Type) int {
} }
return sum return sum
case *reflect.UintType, *reflect.IntType, *reflect.FloatType, *reflect.ComplexType: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
switch t := t.Kind(); t { reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
case reflect.Int, reflect.Uint, reflect.Uintptr: reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return -1 return int(t.Size())
}
return int(v.Size())
} }
return -1 return -1
} }
...@@ -279,130 +277,118 @@ func (d *decoder) int64() int64 { return int64(d.uint64()) } ...@@ -279,130 +277,118 @@ func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) } func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) { func (d *decoder) value(v reflect.Value) {
switch v := v.(type) { switch v.Kind() {
case *reflect.ArrayValue: case reflect.Array:
l := v.Len() l := v.Len()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
d.value(v.Elem(i)) d.value(v.Index(i))
} }
case *reflect.StructValue: case reflect.Struct:
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
d.value(v.Field(i)) d.value(v.Field(i))
} }
case *reflect.SliceValue: case reflect.Slice:
l := v.Len() l := v.Len()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
d.value(v.Elem(i)) d.value(v.Index(i))
} }
case *reflect.IntValue: case reflect.Int8:
switch v.Type().Kind() { v.SetInt(int64(d.int8()))
case reflect.Int8: case reflect.Int16:
v.Set(int64(d.int8())) v.SetInt(int64(d.int16()))
case reflect.Int16: case reflect.Int32:
v.Set(int64(d.int16())) v.SetInt(int64(d.int32()))
case reflect.Int32: case reflect.Int64:
v.Set(int64(d.int32())) v.SetInt(d.int64())
case reflect.Int64:
v.Set(d.int64()) case reflect.Uint8:
} v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
case *reflect.UintValue: v.SetUint(uint64(d.uint16()))
switch v.Type().Kind() { case reflect.Uint32:
case reflect.Uint8: v.SetUint(uint64(d.uint32()))
v.Set(uint64(d.uint8())) case reflect.Uint64:
case reflect.Uint16: v.SetUint(d.uint64())
v.Set(uint64(d.uint16()))
case reflect.Uint32: case reflect.Float32:
v.Set(uint64(d.uint32())) v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Uint64: case reflect.Float64:
v.Set(d.uint64()) v.SetFloat(math.Float64frombits(d.uint64()))
}
case reflect.Complex64:
case *reflect.FloatValue: v.SetComplex(complex(
switch v.Type().Kind() { float64(math.Float32frombits(d.uint32())),
case reflect.Float32: float64(math.Float32frombits(d.uint32())),
v.Set(float64(math.Float32frombits(d.uint32()))) ))
case reflect.Float64: case reflect.Complex128:
v.Set(math.Float64frombits(d.uint64())) v.SetComplex(complex(
} math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
case *reflect.ComplexValue: ))
switch v.Type().Kind() {
case reflect.Complex64:
v.Set(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.Set(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
} }
} }
func (e *encoder) value(v reflect.Value) { func (e *encoder) value(v reflect.Value) {
switch v := v.(type) { switch v.Kind() {
case *reflect.ArrayValue: case reflect.Array:
l := v.Len() l := v.Len()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Elem(i)) e.value(v.Index(i))
} }
case *reflect.StructValue: case reflect.Struct:
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Field(i)) e.value(v.Field(i))
} }
case *reflect.SliceValue: case reflect.Slice:
l := v.Len() l := v.Len()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Elem(i)) e.value(v.Index(i))
} }
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Type().Kind() { switch v.Type().Kind() {
case reflect.Int8: case reflect.Int8:
e.int8(int8(v.Get())) e.int8(int8(v.Int()))
case reflect.Int16: case reflect.Int16:
e.int16(int16(v.Get())) e.int16(int16(v.Int()))
case reflect.Int32: case reflect.Int32:
e.int32(int32(v.Get())) e.int32(int32(v.Int()))
case reflect.Int64: case reflect.Int64:
e.int64(v.Get()) e.int64(v.Int())
} }
case *reflect.UintValue: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Type().Kind() { switch v.Type().Kind() {
case reflect.Uint8: case reflect.Uint8:
e.uint8(uint8(v.Get())) e.uint8(uint8(v.Uint()))
case reflect.Uint16: case reflect.Uint16:
e.uint16(uint16(v.Get())) e.uint16(uint16(v.Uint()))
case reflect.Uint32: case reflect.Uint32:
e.uint32(uint32(v.Get())) e.uint32(uint32(v.Uint()))
case reflect.Uint64: case reflect.Uint64:
e.uint64(v.Get()) e.uint64(v.Uint())
} }
case *reflect.FloatValue: case reflect.Float32, reflect.Float64:
switch v.Type().Kind() { switch v.Type().Kind() {
case reflect.Float32: case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Get()))) e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64: case reflect.Float64:
e.uint64(math.Float64bits(v.Get())) e.uint64(math.Float64bits(v.Float()))
} }
case *reflect.ComplexValue: case reflect.Complex64, reflect.Complex128:
switch v.Type().Kind() { switch v.Type().Kind() {
case reflect.Complex64: case reflect.Complex64:
x := v.Get() x := v.Complex()
e.uint32(math.Float32bits(float32(real(x)))) e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x)))) e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128: case reflect.Complex128:
x := v.Get() x := v.Complex()
e.uint64(math.Float64bits(real(x))) e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x))) e.uint64(math.Float64bits(imag(x)))
} }
......
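The encoding/binary rewrite switches from the old concrete *reflect.XxxValue and *reflect.XxxType wrappers to reflect.Kind switches on plain reflect.Value, but the exported Read/Write behavior is unchanged. A small round-trip sketch under the os.Error-era signatures shown above:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// header is an example; any struct of fixed-size fields works with
// Read and Write.
type header struct {
	Magic uint32
	Count uint16
	Flags uint16
}

func main() {
	var buf bytes.Buffer
	in := header{Magic: 0xfeedface, Count: 3, Flags: 1}
	if err := binary.Write(&buf, binary.BigEndian, in); err != nil {
		fmt.Println(err)
		return
	}

	var out header
	// Read walks the struct fields with the reflect.Kind switch shown
	// in decoder.value above.
	if err := binary.Read(&buf, binary.BigEndian, &out); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("%+v\n", out)
}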
...@@ -152,7 +152,7 @@ func TestWriteT(t *testing.T) { ...@@ -152,7 +152,7 @@ func TestWriteT(t *testing.T) {
t.Errorf("WriteT: have nil, want non-nil") t.Errorf("WriteT: have nil, want non-nil")
} }
tv := reflect.Indirect(reflect.NewValue(ts)).(*reflect.StructValue) tv := reflect.Indirect(reflect.ValueOf(ts))
for i, n := 0, tv.NumField(); i < n; i++ { for i, n := 0, tv.NumField(); i < n; i++ {
err = Write(buf, BigEndian, tv.Field(i).Interface()) err = Write(buf, BigEndian, tv.Field(i).Interface())
if err == nil { if err == nil {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements hexadecimal encoding and decoding. // Package hex implements hexadecimal encoding and decoding.
package hex package hex
import ( import (
......
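Only the package comment changes here (godoc's "Package hex ..." form). For context, a tiny round trip through the package's string helpers, assuming the EncodeToString/DecodeString functions of this era:

package main

import (
	"encoding/hex"
	"fmt"
)

func main() {
	s := hex.EncodeToString([]byte("Go"))
	fmt.Println(s) // 476f
	b, err := hex.DecodeString(s)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(string(b)) // Go
}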
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The line package implements a Reader that reads lines delimited by '\n' or '\r\n'. // Package line implements a Reader that reads lines delimited by '\n' or
// '\r\n'.
package line package line
import ( import (
......
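Again a doc-comment-only change. The Reader it describes hands back one line at a time and uses isPrefix to signal lines longer than the configured limit; a hedged sketch, assuming the NewReader(r, maxLineLength) constructor and ReadLine method of the encoding/line package at this time:

package main

import (
	"encoding/line"
	"fmt"
	"strings"
)

func main() {
	r := line.NewReader(strings.NewReader("alpha\r\nbeta\n"), 64)
	for {
		ln, isPrefix, err := r.ReadLine()
		if err != nil {
			break // os.EOF once the input is exhausted
		}
		// isPrefix would be true only for lines longer than 64 bytes.
		fmt.Println(string(ln), isPrefix)
	}
}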
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package implements the PEM data encoding, which originated in Privacy // Package pem implements the PEM data encoding, which originated in Privacy
// Enhanced Mail. The most common use of PEM encoding today is in TLS keys and // Enhanced Mail. The most common use of PEM encoding today is in TLS keys and
// certificates. See RFC 1421. // certificates. See RFC 1421.
package pem package pem
......
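As the reworded comment says, PEM is the "-----BEGIN ...-----" framing used for TLS keys and certificates (RFC 1421). A minimal decode sketch, assuming the Decode(data) (block, rest) signature of this era; the certificate body below is just base64 filler:

package main

import (
	"encoding/pem"
	"fmt"
)

var data = []byte(`-----BEGIN CERTIFICATE-----
aGVsbG8sIHdvcmxk
-----END CERTIFICATE-----
`)

func main() {
	block, rest := pem.Decode(data)
	if block == nil {
		fmt.Println("no PEM block found")
		return
	}
	fmt.Println(block.Type, len(block.Bytes), len(rest))
}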
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The exec package runs external commands. It wraps os.StartProcess // Package exec runs external commands. It wraps os.StartProcess to make it
// to make it easier to remap stdin and stdout, connect I/O with pipes, // easier to remap stdin and stdout, connect I/O with pipes, and do other
// and do other adjustments. // adjustments.
package exec package exec
// BUG(r): This package should be made even easier to use or merged into os. // BUG(r): This package should be made even easier to use or merged into os.
......
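The exec package comment moves to the "Package exec ..." form; the API it documents is the pre-Go 1 one used by the test below, where Run takes the resolved binary, argv, environment, working directory, and three file-descriptor dispositions. A sketch mirroring that usage (the DevNull/PassThrough constants and the Wait(0) call are my recollection of this era's API, so treat them as assumptions):

package main

import (
	"exec"
	"fmt"
)

func main() {
	path, err := exec.LookPath("echo")
	if err != nil {
		fmt.Println(err)
		return
	}
	// Ignore stdin, inherit stdout/stderr from the current process.
	cmd, err := exec.Run(path, []string{"echo", "hello"}, nil, "",
		exec.DevNull, exec.PassThrough, exec.PassThrough)
	if err != nil {
		fmt.Println(err)
		return
	}
	cmd.Wait(0) // reap the child; return values ignored in this sketch
}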
...@@ -9,19 +9,14 @@ import ( ...@@ -9,19 +9,14 @@ import (
"io/ioutil" "io/ioutil"
"testing" "testing"
"os" "os"
"runtime"
) )
func run(argv []string, stdin, stdout, stderr int) (p *Cmd, err os.Error) { func run(argv []string, stdin, stdout, stderr int) (p *Cmd, err os.Error) {
if runtime.GOOS == "windows" {
argv = append([]string{"cmd", "/c"}, argv...)
}
exe, err := LookPath(argv[0]) exe, err := LookPath(argv[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
p, err = Run(exe, argv, nil, "", stdin, stdout, stderr) return Run(exe, argv, nil, "", stdin, stdout, stderr)
return p, err
} }
func TestRunCat(t *testing.T) { func TestRunCat(t *testing.T) {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/* The datafmt package implements syntax-directed, type-driven formatting /* Package datafmt implements syntax-directed, type-driven formatting
of arbitrary data structures. Formatting a data structure consists of of arbitrary data structures. Formatting a data structure consists of
two phases: first, a parser reads a format specification and builds a two phases: first, a parser reads a format specification and builds a
"compiled" format. Then, the format can be applied repeatedly to "compiled" format. Then, the format can be applied repeatedly to
...@@ -408,20 +408,20 @@ func (s *State) error(msg string) { ...@@ -408,20 +408,20 @@ func (s *State) error(msg string) {
// //
func typename(typ reflect.Type) string { func typename(typ reflect.Type) string {
switch typ.(type) { switch typ.Kind() {
case *reflect.ArrayType: case reflect.Array:
return "array" return "array"
case *reflect.SliceType: case reflect.Slice:
return "array" return "array"
case *reflect.ChanType: case reflect.Chan:
return "chan" return "chan"
case *reflect.FuncType: case reflect.Func:
return "func" return "func"
case *reflect.InterfaceType: case reflect.Interface:
return "interface" return "interface"
case *reflect.MapType: case reflect.Map:
return "map" return "map"
case *reflect.PtrType: case reflect.Ptr:
return "ptr" return "ptr"
} }
return typ.String() return typ.String()
...@@ -519,38 +519,38 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { ...@@ -519,38 +519,38 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
case "*": case "*":
// indirection: operation is type-specific // indirection: operation is type-specific
switch v := value.(type) { switch v := value; v.Kind() {
case *reflect.ArrayValue: case reflect.Array:
if v.Len() <= index { if v.Len() <= index {
return false return false
} }
value = v.Elem(index) value = v.Index(index)
case *reflect.SliceValue: case reflect.Slice:
if v.IsNil() || v.Len() <= index { if v.IsNil() || v.Len() <= index {
return false return false
} }
value = v.Elem(index) value = v.Index(index)
case *reflect.MapValue: case reflect.Map:
s.error("reflection support for maps incomplete") s.error("reflection support for maps incomplete")
case *reflect.PtrValue: case reflect.Ptr:
if v.IsNil() { if v.IsNil() {
return false return false
} }
value = v.Elem() value = v.Elem()
case *reflect.InterfaceValue: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return false return false
} }
value = v.Elem() value = v.Elem()
case *reflect.ChanValue: case reflect.Chan:
s.error("reflection support for chans incomplete") s.error("reflection support for chans incomplete")
case *reflect.FuncValue: case reflect.Func:
s.error("reflection support for funcs incomplete") s.error("reflection support for funcs incomplete")
default: default:
...@@ -560,9 +560,9 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool { ...@@ -560,9 +560,9 @@ func (s *State) eval(fexpr expr, value reflect.Value, index int) bool {
default: default:
// value is value of named field // value is value of named field
var field reflect.Value var field reflect.Value
if sval, ok := value.(*reflect.StructValue); ok { if sval := value; sval.Kind() == reflect.Struct {
field = sval.FieldByName(t.fieldName) field = sval.FieldByName(t.fieldName)
if field == nil { if !field.IsValid() {
// TODO consider just returning false in this case // TODO consider just returning false in this case
s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type())) s.error(fmt.Sprintf("error: no field `%s` in `%s`", t.fieldName, value.Type()))
} }
...@@ -671,8 +671,8 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) { ...@@ -671,8 +671,8 @@ func (f Format) Eval(env Environment, args ...interface{}) ([]byte, os.Error) {
go func() { go func() {
for _, v := range args { for _, v := range args {
fld := reflect.NewValue(v) fld := reflect.ValueOf(v)
if fld == nil { if !fld.IsValid() {
errors <- os.NewError("nil argument") errors <- os.NewError("nil argument")
return return
} }
......
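The datafmt changes follow the same reflect migration as encoding/binary above: instead of type-switching on *reflect.ArrayValue, *reflect.StructValue and friends, the code switches on reflect.Kind and uses Index/Field/Elem on a plain reflect.Value. A standalone sketch of the pattern (not datafmt's own code):

package main

import (
	"fmt"
	"reflect"
)

// kindName mirrors the typename helper above: it reports a coarse
// category for a type by switching on its reflect.Kind.
func kindName(t reflect.Type) string {
	switch t.Kind() {
	case reflect.Array, reflect.Slice:
		return "array"
	case reflect.Map:
		return "map"
	case reflect.Ptr:
		return "ptr"
	case reflect.Struct:
		return "struct"
	}
	return t.String()
}

func main() {
	v := reflect.ValueOf([]int{1, 2, 3})
	fmt.Println(kindName(v.Type())) // array
	fmt.Println(v.Index(1).Int())   // 2, via the new Index/Int accessors
}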
...@@ -8,7 +8,10 @@ ...@@ -8,7 +8,10 @@
// and the X Render extension. // and the X Render extension.
package draw package draw
import "image" import (
"image"
"image/ycbcr"
)
// m is the maximum color value returned by image.Color.RGBA. // m is the maximum color value returned by image.Color.RGBA.
const m = 1<<16 - 1 const m = 1<<16 - 1
...@@ -65,29 +68,42 @@ func DrawMask(dst Image, r image.Rectangle, src image.Image, sp image.Point, mas ...@@ -65,29 +68,42 @@ func DrawMask(dst Image, r image.Rectangle, src image.Image, sp image.Point, mas
if dst0, ok := dst.(*image.RGBA); ok { if dst0, ok := dst.(*image.RGBA); ok {
if op == Over { if op == Over {
if mask == nil { if mask == nil {
if src0, ok := src.(*image.ColorImage); ok { switch src0 := src.(type) {
case *image.ColorImage:
drawFillOver(dst0, r, src0) drawFillOver(dst0, r, src0)
return return
} case *image.RGBA:
if src0, ok := src.(*image.RGBA); ok {
drawCopyOver(dst0, r, src0, sp) drawCopyOver(dst0, r, src0, sp)
return return
case *image.NRGBA:
drawNRGBAOver(dst0, r, src0, sp)
return
case *ycbcr.YCbCr:
drawYCbCr(dst0, r, src0, sp)
return
} }
} else if mask0, ok := mask.(*image.Alpha); ok { } else if mask0, ok := mask.(*image.Alpha); ok {
if src0, ok := src.(*image.ColorImage); ok { switch src0 := src.(type) {
case *image.ColorImage:
drawGlyphOver(dst0, r, src0, mask0, mp) drawGlyphOver(dst0, r, src0, mask0, mp)
return return
} }
} }
} else { } else {
if mask == nil { if mask == nil {
if src0, ok := src.(*image.ColorImage); ok { switch src0 := src.(type) {
case *image.ColorImage:
drawFillSrc(dst0, r, src0) drawFillSrc(dst0, r, src0)
return return
} case *image.RGBA:
if src0, ok := src.(*image.RGBA); ok {
drawCopySrc(dst0, r, src0, sp) drawCopySrc(dst0, r, src0, sp)
return return
case *image.NRGBA:
drawNRGBASrc(dst0, r, src0, sp)
return
case *ycbcr.YCbCr:
drawYCbCr(dst0, r, src0, sp)
return
} }
} }
} }
...@@ -224,6 +240,36 @@ func drawCopyOver(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image. ...@@ -224,6 +240,36 @@ func drawCopyOver(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.
} }
} }
func drawNRGBAOver(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) {
for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride]
spix := src.Pix[sy*src.Stride : (sy+1)*src.Stride]
for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 {
// Convert from non-premultiplied color to pre-multiplied color.
// The order of operations here is to match the NRGBAColor.RGBA
// method in image/color.go.
snrgba := spix[sx]
sa := uint32(snrgba.A)
sr := uint32(snrgba.R) * 0x101 * sa / 0xff
sg := uint32(snrgba.G) * 0x101 * sa / 0xff
sb := uint32(snrgba.B) * 0x101 * sa / 0xff
sa *= 0x101
rgba := dpix[x]
dr := uint32(rgba.R)
dg := uint32(rgba.G)
db := uint32(rgba.B)
da := uint32(rgba.A)
a := (m - sa) * 0x101
dr = (dr*a + sr*m) / m
dg = (dg*a + sg*m) / m
db = (db*a + sb*m) / m
da = (da*a + sa*m) / m
dpix[x] = image.RGBAColor{uint8(dr >> 8), uint8(dg >> 8), uint8(db >> 8), uint8(da >> 8)}
}
}
}
func drawGlyphOver(dst *image.RGBA, r image.Rectangle, src *image.ColorImage, mask *image.Alpha, mp image.Point) { func drawGlyphOver(dst *image.RGBA, r image.Rectangle, src *image.ColorImage, mask *image.Alpha, mp image.Point) {
x0, x1 := r.Min.X, r.Max.X x0, x1 := r.Min.X, r.Max.X
y0, y1 := r.Min.Y, r.Max.Y y0, y1 := r.Min.Y, r.Max.Y
...@@ -311,6 +357,73 @@ func drawCopySrc(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.P ...@@ -311,6 +357,73 @@ func drawCopySrc(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.P
} }
} }
func drawNRGBASrc(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) {
for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride]
spix := src.Pix[sy*src.Stride : (sy+1)*src.Stride]
for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 {
// Convert from non-premultiplied color to pre-multiplied color.
// The order of operations here is to match the NRGBAColor.RGBA
// method in image/color.go.
snrgba := spix[sx]
sa := uint32(snrgba.A)
sr := uint32(snrgba.R) * 0x101 * sa / 0xff
sg := uint32(snrgba.G) * 0x101 * sa / 0xff
sb := uint32(snrgba.B) * 0x101 * sa / 0xff
sa *= 0x101
dpix[x] = image.RGBAColor{uint8(sr >> 8), uint8(sg >> 8), uint8(sb >> 8), uint8(sa >> 8)}
}
}
}
func drawYCbCr(dst *image.RGBA, r image.Rectangle, src *ycbcr.YCbCr, sp image.Point) {
// A YCbCr image is always fully opaque, and so if the mask is implicitly nil
// (i.e. fully opaque) then the op is effectively always Src.
var (
yy, cb, cr uint8
rr, gg, bb uint8
)
switch src.SubsampleRatio {
case ycbcr.SubsampleRatio422:
for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride]
for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 {
i := sx / 2
yy = src.Y[sy*src.YStride+sx]
cb = src.Cb[sy*src.CStride+i]
cr = src.Cr[sy*src.CStride+i]
rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr)
dpix[x] = image.RGBAColor{rr, gg, bb, 255}
}
}
case ycbcr.SubsampleRatio420:
for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride]
for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 {
i, j := sx/2, sy/2
yy = src.Y[sy*src.YStride+sx]
cb = src.Cb[j*src.CStride+i]
cr = src.Cr[j*src.CStride+i]
rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr)
dpix[x] = image.RGBAColor{rr, gg, bb, 255}
}
}
default:
// Default to 4:4:4 subsampling.
for y, sy := r.Min.Y, sp.Y; y != r.Max.Y; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride : (y+1)*dst.Stride]
for x, sx := r.Min.X, sp.X; x != r.Max.X; x, sx = x+1, sx+1 {
yy = src.Y[sy*src.YStride+sx]
cb = src.Cb[sy*src.CStride+sx]
cr = src.Cr[sy*src.CStride+sx]
rr, gg, bb = ycbcr.YCbCrToRGB(yy, cb, cr)
dpix[x] = image.RGBAColor{rr, gg, bb, 255}
}
}
}
}
func drawRGBA(dst *image.RGBA, r image.Rectangle, src image.Image, sp image.Point, mask image.Image, mp image.Point, op Op) { func drawRGBA(dst *image.RGBA, r image.Rectangle, src image.Image, sp image.Point, mask image.Image, mp image.Point, op Op) {
x0, x1, dx := r.Min.X, r.Max.X, 1 x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1 y0, y1, dy := r.Min.Y, r.Max.Y, 1
......
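DrawMask gains fast paths for *image.NRGBA and *ycbcr.YCbCr sources alongside the existing ColorImage and RGBA cases. A sketch that exercises the NRGBA-over-RGBA path; it is written as if it sat next to draw_test.go (package draw), so the draw package's own import path does not matter, and the function name is hypothetical:

package draw

import (
	"fmt"
	"image"
)

// nrgbaFastPathDemo draws a single NRGBA pixel over an RGBA destination
// with a nil mask and the Over operator, which dispatches to
// drawNRGBAOver above.
func nrgbaFastPathDemo() {
	dst := image.NewRGBA(16, 16)
	src := image.NewNRGBA(16, 16)
	src.Set(8, 8, image.NRGBAColor{R: 0, G: 136, B: 0, A: 90})

	DrawMask(dst, image.Rect(0, 0, 16, 16), src, image.Point{}, nil, image.Point{}, Over)
	fmt.Println(dst.At(8, 8))
}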
...@@ -6,6 +6,7 @@ package draw ...@@ -6,6 +6,7 @@ package draw
import ( import (
"image" "image"
"image/ycbcr"
"testing" "testing"
) )
...@@ -43,6 +44,34 @@ func vgradAlpha(alpha int) image.Image { ...@@ -43,6 +44,34 @@ func vgradAlpha(alpha int) image.Image {
return m return m
} }
func vgradGreenNRGBA(alpha int) image.Image {
m := image.NewNRGBA(16, 16)
for y := 0; y < 16; y++ {
for x := 0; x < 16; x++ {
m.Set(x, y, image.RGBAColor{0, uint8(y * 0x11), 0, uint8(alpha)})
}
}
return m
}
func vgradCr() image.Image {
m := &ycbcr.YCbCr{
Y: make([]byte, 16*16),
Cb: make([]byte, 16*16),
Cr: make([]byte, 16*16),
YStride: 16,
CStride: 16,
SubsampleRatio: ycbcr.SubsampleRatio444,
Rect: image.Rect(0, 0, 16, 16),
}
for y := 0; y < 16; y++ {
for x := 0; x < 16; x++ {
m.Cr[y*m.CStride+x] = uint8(y * 0x11)
}
}
return m
}
func hgradRed(alpha int) Image { func hgradRed(alpha int) Image {
m := image.NewRGBA(16, 16) m := image.NewRGBA(16, 16)
for y := 0; y < 16; y++ { for y := 0; y < 16; y++ {
...@@ -95,6 +124,27 @@ var drawTests = []drawTest{ ...@@ -95,6 +124,27 @@ var drawTests = []drawTest{
{"copyAlphaSrc", vgradGreen(90), fillAlpha(192), Src, image.RGBAColor{0, 36, 0, 68}}, {"copyAlphaSrc", vgradGreen(90), fillAlpha(192), Src, image.RGBAColor{0, 36, 0, 68}},
{"copyNil", vgradGreen(90), nil, Over, image.RGBAColor{88, 48, 0, 255}}, {"copyNil", vgradGreen(90), nil, Over, image.RGBAColor{88, 48, 0, 255}},
{"copyNilSrc", vgradGreen(90), nil, Src, image.RGBAColor{0, 48, 0, 90}}, {"copyNilSrc", vgradGreen(90), nil, Src, image.RGBAColor{0, 48, 0, 90}},
// Uniform mask (100%, 75%, nil) and variable NRGBA source.
// At (x, y) == (8, 8):
// The destination pixel is {136, 0, 0, 255}.
// The source pixel is {0, 136, 0, 90} in NRGBA-space, which is {0, 48, 0, 90} in RGBA-space.
// The result pixel is different than in the "copy*" test cases because of rounding errors.
{"nrgba", vgradGreenNRGBA(90), fillAlpha(255), Over, image.RGBAColor{88, 46, 0, 255}},
{"nrgbaSrc", vgradGreenNRGBA(90), fillAlpha(255), Src, image.RGBAColor{0, 46, 0, 90}},
{"nrgbaAlpha", vgradGreenNRGBA(90), fillAlpha(192), Over, image.RGBAColor{100, 34, 0, 255}},
{"nrgbaAlphaSrc", vgradGreenNRGBA(90), fillAlpha(192), Src, image.RGBAColor{0, 34, 0, 68}},
{"nrgbaNil", vgradGreenNRGBA(90), nil, Over, image.RGBAColor{88, 46, 0, 255}},
{"nrgbaNilSrc", vgradGreenNRGBA(90), nil, Src, image.RGBAColor{0, 46, 0, 90}},
// Uniform mask (100%, 75%, nil) and variable YCbCr source.
// At (x, y) == (8, 8):
// The destination pixel is {136, 0, 0, 255}.
// The source pixel is {0, 0, 136} in YCbCr-space, which is {11, 38, 0, 255} in RGB-space.
{"ycbcr", vgradCr(), fillAlpha(255), Over, image.RGBAColor{11, 38, 0, 255}},
{"ycbcrSrc", vgradCr(), fillAlpha(255), Src, image.RGBAColor{11, 38, 0, 255}},
{"ycbcrAlpha", vgradCr(), fillAlpha(192), Over, image.RGBAColor{42, 28, 0, 255}},
{"ycbcrAlphaSrc", vgradCr(), fillAlpha(192), Src, image.RGBAColor{8, 28, 0, 192}},
{"ycbcrNil", vgradCr(), nil, Over, image.RGBAColor{11, 38, 0, 255}},
{"ycbcrNilSrc", vgradCr(), nil, Src, image.RGBAColor{11, 38, 0, 255}},
// Variable mask and variable source. // Variable mask and variable source.
// At (x, y) == (8, 8): // At (x, y) == (8, 8):
// The destination pixel is {136, 0, 0, 255}. // The destination pixel is {136, 0, 0, 255}.
......
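The new table comments can be cross-checked against the drawNRGBAOver arithmetic: for the test pixel {R:0, G:136, B:0, A:90}, premultiplying gives 136*0x101*90/0xff = 12336 and 90*0x101 = 23130, whose top bytes are 48 and 90, matching "{0, 48, 0, 90} in RGBA-space". A tiny check in Go:

package main

import "fmt"

func main() {
	// Reproduce the NRGBA -> RGBA conversion from drawNRGBAOver for the
	// test pixel {R: 0, G: 136, B: 0, A: 90}.
	a := uint32(90)
	g := uint32(136) * 0x101 * a / 0xff // 12336
	a *= 0x101                          // 23130
	fmt.Println(g>>8, a>>8)             // 48 90
}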