Commit 02bb39f4 by Edward Thomson

stream registration: take an enum type

Accept an enum (`git_stream_t`) during custom stream registration that
indicates whether the registration structure should be used for standard
(non-TLS) streams or TLS streams.
parent 52478d7d
...@@ -72,19 +72,31 @@ typedef struct { ...@@ -72,19 +72,31 @@ typedef struct {
} git_stream_registration; } git_stream_registration;
/** /**
* The type of stream to register.
*/
typedef enum {
/** A standard (non-TLS) socket. */
GIT_STREAM_STANDARD = 1,
/** A TLS-encrypted socket. */
GIT_STREAM_TLS = 2,
} git_stream_t;
/**
* Register stream constructors for the library to use * Register stream constructors for the library to use
* *
* If a registration structure is already set, it will be overwritten. * If a registration structure is already set, it will be overwritten.
* Pass `NULL` in order to deregister the current constructor and return * Pass `NULL` in order to deregister the current constructor and return
* to the system defaults. * to the system defaults.
* *
* The type parameter may be a bitwise AND of types.
*
* @param type the type or types of stream to register
* @param registration the registration data * @param registration the registration data
* @param tls 1 if the registration is for TLS streams, 0 for regular
* (insecure) sockets
* @return 0 or an error code * @return 0 or an error code
*/ */
GIT_EXTERN(int) git_stream_register( GIT_EXTERN(int) git_stream_register(
int tls, git_stream_registration *registration); git_stream_t type, git_stream_registration *registration);
/** @name Deprecated TLS Stream Registration Functions /** @name Deprecated TLS Stream Registration Functions
* *
......
...@@ -36,22 +36,42 @@ int git_stream_registry_global_init(void) ...@@ -36,22 +36,42 @@ int git_stream_registry_global_init(void)
return 0; return 0;
} }
int git_stream_registry_lookup(git_stream_registration *out, int tls) GIT_INLINE(void) stream_registration_cpy(
git_stream_registration *target,
git_stream_registration *src)
{ {
git_stream_registration *target = tls ? if (src)
&stream_registry.callbacks : memcpy(target, src, sizeof(git_stream_registration));
&stream_registry.tls_callbacks; else
memset(target, 0, sizeof(git_stream_registration));
}
int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type)
{
git_stream_registration *target;
int error = GIT_ENOTFOUND; int error = GIT_ENOTFOUND;
assert(out); assert(out);
switch(type) {
case GIT_STREAM_STANDARD:
target = &stream_registry.callbacks;
break;
case GIT_STREAM_TLS:
target = &stream_registry.tls_callbacks;
break;
default:
assert(0);
return -1;
}
if (git_rwlock_rdlock(&stream_registry.lock) < 0) { if (git_rwlock_rdlock(&stream_registry.lock) < 0) {
giterr_set(GITERR_OS, "failed to lock stream registry"); giterr_set(GITERR_OS, "failed to lock stream registry");
return -1; return -1;
} }
if (target->init) { if (target->init) {
memcpy(out, target, sizeof(git_stream_registration)); stream_registration_cpy(out, target);
error = 0; error = 0;
} }
...@@ -59,12 +79,8 @@ int git_stream_registry_lookup(git_stream_registration *out, int tls) ...@@ -59,12 +79,8 @@ int git_stream_registry_lookup(git_stream_registration *out, int tls)
return error; return error;
} }
int git_stream_register(int tls, git_stream_registration *registration) int git_stream_register(git_stream_t type, git_stream_registration *registration)
{ {
git_stream_registration *target = tls ?
&stream_registry.callbacks :
&stream_registry.tls_callbacks;
assert(!registration || registration->init); assert(!registration || registration->init);
GITERR_CHECK_VERSION(registration, GIT_STREAM_VERSION, "stream_registration"); GITERR_CHECK_VERSION(registration, GIT_STREAM_VERSION, "stream_registration");
...@@ -74,10 +90,11 @@ int git_stream_register(int tls, git_stream_registration *registration) ...@@ -74,10 +90,11 @@ int git_stream_register(int tls, git_stream_registration *registration)
return -1; return -1;
} }
if (registration) if ((type & GIT_STREAM_STANDARD) == GIT_STREAM_STANDARD)
memcpy(target, registration, sizeof(git_stream_registration)); stream_registration_cpy(&stream_registry.callbacks, registration);
else
memset(target, 0, sizeof(git_stream_registration)); if ((type & GIT_STREAM_TLS) == GIT_STREAM_TLS)
stream_registration_cpy(&stream_registry.tls_callbacks, registration);
git_rwlock_wrunlock(&stream_registry.lock); git_rwlock_wrunlock(&stream_registry.lock);
return 0; return 0;
...@@ -92,8 +109,8 @@ int git_stream_register_tls(git_stream_cb ctor) ...@@ -92,8 +109,8 @@ int git_stream_register_tls(git_stream_cb ctor)
registration.init = ctor; registration.init = ctor;
registration.wrap = NULL; registration.wrap = NULL;
return git_stream_register(1, &registration); return git_stream_register(GIT_STREAM_TLS, &registration);
} else { } else {
return git_stream_register(1, NULL); return git_stream_register(GIT_STREAM_TLS, NULL);
} }
} }
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
int git_stream_registry_global_init(void); int git_stream_registry_global_init(void);
/** Lookup a stream registration. */ /** Lookup a stream registration. */
extern int git_stream_registry_lookup(git_stream_registration *out, int tls); extern int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type);
#endif #endif
...@@ -224,7 +224,7 @@ int git_socket_stream_new( ...@@ -224,7 +224,7 @@ int git_socket_stream_new(
assert(out && host && port); assert(out && host && port);
if ((error = git_stream_registry_lookup(&custom, 0)) == 0) if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_STANDARD)) == 0)
init = custom.init; init = custom.init;
else if (error == GIT_ENOTFOUND) else if (error == GIT_ENOTFOUND)
init = default_socket_stream_new; init = default_socket_stream_new;
......
...@@ -23,7 +23,7 @@ int git_tls_stream_new(git_stream **out, const char *host, const char *port) ...@@ -23,7 +23,7 @@ int git_tls_stream_new(git_stream **out, const char *host, const char *port)
assert(out && host && port); assert(out && host && port);
if ((error = git_stream_registry_lookup(&custom, 1)) == 0) { if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_TLS)) == 0) {
init = custom.init; init = custom.init;
} else if (error == GIT_ENOTFOUND) { } else if (error == GIT_ENOTFOUND) {
#ifdef GIT_SECURE_TRANSPORT #ifdef GIT_SECURE_TRANSPORT
...@@ -52,7 +52,7 @@ int git_tls_stream_wrap(git_stream **out, git_stream *in, const char *host) ...@@ -52,7 +52,7 @@ int git_tls_stream_wrap(git_stream **out, git_stream *in, const char *host)
assert(out && in); assert(out && in);
if (git_stream_registry_lookup(&custom, 1) == 0) { if (git_stream_registry_lookup(&custom, GIT_STREAM_TLS) == 0) {
wrap = custom.wrap; wrap = custom.wrap;
} else { } else {
#ifdef GIT_SECURE_TRANSPORT #ifdef GIT_SECURE_TRANSPORT
......
...@@ -7,6 +7,11 @@ ...@@ -7,6 +7,11 @@
static git_stream test_stream; static git_stream test_stream;
static int ctor_called; static int ctor_called;
void test_core_stream__cleanup(void)
{
cl_git_pass(git_stream_register(GIT_STREAM_TLS | GIT_STREAM_STANDARD, NULL));
}
static int test_stream_init(git_stream **out, const char *host, const char *port) static int test_stream_init(git_stream **out, const char *host, const char *port)
{ {
GIT_UNUSED(host); GIT_UNUSED(host);
...@@ -39,14 +44,14 @@ void test_core_stream__register_insecure(void) ...@@ -39,14 +44,14 @@ void test_core_stream__register_insecure(void)
registration.wrap = test_stream_wrap; registration.wrap = test_stream_wrap;
ctor_called = 0; ctor_called = 0;
cl_git_pass(git_stream_register(0, &registration)); cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, &registration));
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80")); cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(1, ctor_called); cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream); cl_assert_equal_p(&test_stream, stream);
ctor_called = 0; ctor_called = 0;
stream = NULL; stream = NULL;
cl_git_pass(git_stream_register(0, NULL)); cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, NULL));
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80")); cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(0, ctor_called); cl_assert_equal_i(0, ctor_called);
...@@ -66,14 +71,14 @@ void test_core_stream__register_tls(void) ...@@ -66,14 +71,14 @@ void test_core_stream__register_tls(void)
registration.wrap = test_stream_wrap; registration.wrap = test_stream_wrap;
ctor_called = 0; ctor_called = 0;
cl_git_pass(git_stream_register(1, &registration)); cl_git_pass(git_stream_register(GIT_STREAM_TLS, &registration));
cl_git_pass(git_tls_stream_new(&stream, "localhost", "443")); cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
cl_assert_equal_i(1, ctor_called); cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream); cl_assert_equal_p(&test_stream, stream);
ctor_called = 0; ctor_called = 0;
stream = NULL; stream = NULL;
cl_git_pass(git_stream_register(1, NULL)); cl_git_pass(git_stream_register(GIT_STREAM_TLS, NULL));
error = git_tls_stream_new(&stream, "localhost", "443"); error = git_tls_stream_new(&stream, "localhost", "443");
/* We don't have TLS support enabled, or we're on Windows, /* We don't have TLS support enabled, or we're on Windows,
...@@ -91,6 +96,28 @@ void test_core_stream__register_tls(void) ...@@ -91,6 +96,28 @@ void test_core_stream__register_tls(void)
git_stream_free(stream); git_stream_free(stream);
} }
void test_core_stream__register_both(void)
{
git_stream *stream;
git_stream_registration registration = {0};
registration.version = 1;
registration.init = test_stream_init;
registration.wrap = test_stream_wrap;
cl_git_pass(git_stream_register(GIT_STREAM_STANDARD | GIT_STREAM_TLS, &registration));
ctor_called = 0;
cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
ctor_called = 0;
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
}
void test_core_stream__register_tls_deprecated(void) void test_core_stream__register_tls_deprecated(void)
{ {
git_stream *stream; git_stream *stream;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment