-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlstm.h
68 lines (50 loc) · 1.57 KB
/
lstm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#ifndef mj_lstm_h
#define mj_lstm_h
#include <stddef.h>
#include <immintrin.h>
#include <fmaintrin.h>
#ifdef __cplusplus
extern "C" {
#endif
/* Scalar type used for every network parameter and activation. */
typedef double real;
/* AVX vector of reals: __m256d packs four doubles, so one realv spans four reals. */
typedef __m256d realv;

/* Declares one buffer member that can be viewed either as packed AVX vectors
 * (.v) or as individual scalars (.i); both views alias the same pointer value.
 *
 * The expansion intentionally has NO trailing ';' -- each use site supplies it.
 * Previously the macro ended in ';' too, so `MEMBER(x);` produced `...} x;;`,
 * an empty member declaration that ISO C does not permit inside a struct.
 *
 * NOTE(review): this generic macro name remains defined for every file that
 * includes lstm.h; it is kept (not #undef'd) in case lstm.c relies on it. */
#define MEMBER(name) union { realv* v; real* i; } name

typedef struct slstm
{
    /* Network dimensions. */
    size_t num_inputs;
    size_t num_outputs;
    size_t num_hidden_layers;
    size_t* num_hiddens;   /* presumably one entry per hidden layer (num_hidden_layers entries) -- confirm */
    /* NOTE(review): sizing bookkeeping; exact meaning is defined in lstm.c. */
    size_t max_state_size;
    size_t max_activation_state_size;

    /* Recurrent state. */
    MEMBER(memory_cells);

    /* Input-side weights for the cell input and the three gates. */
    MEMBER(weights_input);
    MEMBER(weights_input_gate);
    MEMBER(weights_forget_gate);
    MEMBER(weights_output_gate);

    /* Fed-back activations and the recurrent weights applied to them. */
    MEMBER(feedback_activations);
    MEMBER(weights_feedback_input);
    MEMBER(weights_feedback_input_gate);
    MEMBER(weights_feedback_forget_gate);
    MEMBER(weights_feedback_output_gate);

    /* Weights of the final (output) layer. */
    MEMBER(weights_last_layer);

    /* Biases for the cell input and the three gates. */
    MEMBER(bias_input);
    MEMBER(bias_input_gate);
    MEMBER(bias_forget_gate);
    MEMBER(bias_output_gate);

    /* Memory-cell-to-gate weights (presumably LSTM "peephole" connections -- confirm). */
    MEMBER(weights_memory_to_output_gate);
    MEMBER(weights_memory_to_input_gate);
    MEMBER(weights_memory_to_forget_gate);
} lstm;
/* `restrict` is a C99 keyword with no C++ counterpart, and older / MSVC C modes
 * lack it too. This header is meant to be consumable from C++ (see the
 * extern "C" guard), so the qualifier is spelled through a macro that falls
 * back to the widely supported `__restrict` extension. */
#ifndef LSTM_RESTRICT
#  if defined(__cplusplus) || !defined(__STDC_VERSION__) || __STDC_VERSION__ < 199901L
#    define LSTM_RESTRICT __restrict
#  else
#    define LSTM_RESTRICT restrict
#  endif
#endif

/* Check README.md for documentation on these functions. The notes below only
 * restate what the declarations themselves show. */

/* Runs the network once: reads `input`, writes `output`, and may modify
 * `*lstm` (the pointer is not const). The three pointers must not alias
 * one another.
 * NOTE(review): presumably `input` holds num_inputs reals and `output`
 * receives num_outputs reals -- confirm against lstm.c. */
void propagate(lstm* LSTM_RESTRICT lstm, const real* LSTM_RESTRICT input, real* LSTM_RESTRICT output);

/* Creates a network with the given input/output widths and
 * `num_hidden_layers` hidden layers sized by `num_hiddens`.
 * NOTE(review): presumably `num_hiddens` has num_hidden_layers entries, the
 * result is released with free_lstm, and NULL signals failure -- confirm. */
lstm* allocate_lstm( size_t num_inputs, size_t num_outputs, const size_t* num_hiddens, size_t num_hidden_layers );

/* Releases a network (presumably one obtained from allocate_lstm). */
void free_lstm( lstm* lstm );

/* Resets the network. NOTE(review): presumably clears recurrent state
 * (memory cells / feedback activations) while keeping weights -- confirm. */
void reset_lstm( lstm* lstm );

/* Size of the buffer serialize_lstm writes.
 * NOTE(review): presumably a count of `real` elements, not bytes -- confirm. */
size_t serialized_lstm_size( const lstm* lstm );

/* Writes the network into caller-provided `serialized`
 * (presumably at least serialized_lstm_size(lstm) elements long). */
void serialize_lstm( const lstm* lstm, real* serialized );

/* Loads the network from `serialized`, presumably as produced by serialize_lstm. */
void unserialize_lstm( lstm* lstm, const real* serialized );
#ifdef __cplusplus
}
#endif
#endif