@@ -18,8 +18,6 @@ limitations under the License.
#include <stdio.h>
#include <stdlib.h>
2020
21- // IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds,
22- // so we must not rely on transitive includes from other headers).
#include <string>
#include <unordered_map>
#include <vector>
@@ -31,136 +29,128 @@ limitations under the License.
3129#include " tensorflow/cc/framework/grad_op_registry.h"
3230
3331namespace tensorflow {
34- namespace java {
35-
36- using namespace tsl ;
37- using namespace std ;
38-
39- unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
40-
41- // Cast helper (inspired by TF C-API)
42- template <typename T, typename U>
43- T* struct_cast (U* ptr) {
44- return static_cast <T*>(static_cast <void *>(ptr));
45- }
46-
47- // Bridge called by TF runtime when building gradients for op
48- Status CustomGradFunc (const Scope& scope,
49- const Operation& op,
50- const vector<Output>& grad_inputs,
51- vector<Output>* grad_outputs) {
52- const string& op_type = op.node ()->type_string ();
53- auto found_adapter = g_grad_func_adapters.find (op_type);
54- if (found_adapter == g_grad_func_adapters.end ()) {
55- return errors::NotFound (" No gradient adapter found for operation " , op_type);
56- }
57-
58- TFJ_GradFuncAdapter adapter = found_adapter->second ;
59- if (adapter == nullptr ) {
60- return errors::Unknown (" Null Java gradient adapter for op " , op_type);
61- }
62-
63- const int num_inputs = static_cast <int >(grad_inputs.size ());
64-
65- TF_Output* inputs = nullptr ;
66- if (num_inputs > 0 ) {
67- inputs = static_cast <TF_Output*>(malloc (num_inputs * sizeof (TF_Output)));
68- if (inputs == nullptr ) {
69- return errors::ResourceExhausted (
70- " Out of memory allocating inputs for custom gradient of op " , op_type);
71- }
72- }
73-
74- for (int i = 0 ; i < num_inputs; ++i) {
75- const Output& grad_input = grad_inputs[i];
76- inputs[i].oper = struct_cast<TF_Operation>(grad_input.node ());
77- inputs[i].index = grad_input.index ();
78- }
79-
80- TF_Output* outputs = nullptr ;
81-
82- LOG (INFO) << " Calling Java gradient function for operation of type " << op_type;
83- const int num_outputs = adapter (
84- static_cast <TFJ_GraphId>(scope.graph ()),
85- struct_cast<TFJ_Scope>(const_cast <Scope*>(&scope)),
86- struct_cast<TF_Operation>(op.node ()),
87- inputs,
88- num_inputs,
89- &outputs);
90-
91- if (inputs != nullptr ) free (inputs);
92-
93- // Adapter contract:
94- // - num_outputs < 0 indicates failure
95- // - num_outputs == 0: OK, outputs may be nullptr
96- // - num_outputs > 0: outputs must be non-null
97- if (num_outputs < 0 ) {
98- if (outputs != nullptr ) free (outputs);
99- return errors::Unknown (" Java custom gradient adapter failed for op " , op_type,
100- " (num_outputs=" , num_outputs, " )" );
101- }
102- if (num_outputs > 0 && outputs == nullptr ) {
103- return errors::Unknown (" Java custom gradient adapter returned null outputs for op " ,
104- op_type, " with num_outputs=" , num_outputs);
105- }
106-
107- grad_outputs->reserve (grad_outputs->size () + static_cast <size_t >(num_outputs));
108-
109- for (int i = 0 ; i < num_outputs; ++i) {
110- const TF_Output out = outputs[i];
111-
112- // Convention: out.oper == nullptr => NoGradient
113- if (out.oper == nullptr ) {
114- grad_outputs->push_back (Output ()); // TF interprets empty Output as "no grad"
115- continue ;
32+ namespace java {
33+ using namespace tsl ;
34+ using namespace std ;
35+
36+ unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
37+
38+ // / This method can be used to cast a pointer to/from a C struct that contains only that pointer. It is a bit
39+ // /
40+ // / It has been "inspired" by the TensorFlow C API code, as found at this location when time of writing:
41+ // / https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658
42+ template <typename T, typename U> T* struct_cast (U* ptr) {
43+ return static_cast <T*>(static_cast <void *>(ptr));
44+ }
45+
46+ // / This function is called by the TensorFlow runtime when it is time to add gradient operations of `op` to the
47+ // / graph using the given `scope`.
48+ // / We use it as a bridge between the C++ signature in TensorFlow (tensorflow::op::GradFunc) and our custom
49+ // / "C" version (TFJ_GradFuncAdapter).
50+ Status CustomGradFunc (const Scope& scope,
51+ const Operation& op,
52+ const vector<Output>& grad_inputs,
53+ vector<Output>* grad_outputs)
54+ {
55+ const string& op_type = op.node ()->type_string ();
56+ auto found_adapter = g_grad_func_adapters.find (op_type);
57+ if (found_adapter == g_grad_func_adapters.end ()) {
58+ return errors::NotFound (" No gradient adapter found for operation " , op_type);
59+ }
60+
61+ TFJ_GradFuncAdapter adapter = found_adapter->second ;
62+ if (adapter == NULL ) {
63+ return errors::Unknown (" Null Java gradient adapter for operation " , op_type);
64+ }
65+
66+ int num_inputs = grad_inputs.size ();
67+ TF_Output* inputs = NULL ;
68+ if (num_inputs > 0 ) {
69+ inputs = (TF_Output*)malloc (num_inputs * sizeof (TF_Output));
70+ if (inputs == NULL ) {
71+ return errors::ResourceExhausted (
72+ " Out of memory allocating inputs for custom gradient of op " , op_type);
73+ }
74+ }
75+
76+ for (int i = 0 ; i < num_inputs; ++i) {
77+ Output grad_input = grad_inputs[i];
78+ inputs[i].oper = struct_cast<TF_Operation>(grad_input.node ());
79+ inputs[i].index = grad_input.index ();
80+ }
81+
82+ TF_Output* outputs = NULL ;
83+ LOG (INFO) << " Calling Java gradient function for operation of type " << op_type;
84+ int num_outputs = adapter (
85+ static_cast <TFJ_GraphId>(scope.graph ()),
86+ struct_cast<TFJ_Scope>(const_cast <Scope*>(&scope)),
87+ struct_cast<TF_Operation>(op.node ()),
88+ inputs,
89+ num_inputs,
90+ &outputs
91+ );
92+
93+ if (inputs != NULL ) free (inputs);
94+
95+ if (num_outputs < 0 ) {
96+ if (outputs != NULL ) free (outputs);
97+ return errors::Unknown (" Java custom gradient adapter failed for operation " , op_type,
98+ " (num_outputs=" , num_outputs, " )" );
99+ }
100+ if (num_outputs > 0 && outputs == NULL ) {
101+ return errors::Unknown (" Java custom gradient adapter returned null outputs for operation " ,
102+ op_type, " with num_outputs=" , num_outputs);
103+ }
104+
105+ for (int i = 0 ; i < num_outputs; ++i) {
106+ TF_Output output = outputs[i];
107+
108+ // Convention: output.oper == NULL => NoGradient
109+ if (output.oper == NULL ) {
110+ grad_outputs->push_back (Output ());
111+ } else {
112+ grad_outputs->push_back (Output (struct_cast<Node>(output.oper ), output.index ));
113+ }
114+ }
115+
116+ if (outputs != NULL ) free (outputs); // outputs are allocated from Java but must be freed here
117+ return OkStatus ();
118+ }
116119 }
117-
118- grad_outputs->push_back (Output (struct_cast<Node>(out.oper ), out.index ));
119- }
120-
121- if (outputs != nullptr ) free (outputs); // allocated from Java via malloc
122- return OkStatus ();
123120}
124121
125- } // namespace java
126- } // namespace tensorflow
127-
using namespace tensorflow::ops;
using namespace tensorflow::java;
130124
131125bool TFJ_HasGradient (const char * op_type) {
132- GradFunc dummy;
133- tsl::Status status = GradOpRegistry::Global ()->Lookup (op_type, &dummy);
134- return status.ok ();
126+ GradFunc dummy;
127+ tsl::Status status = GradOpRegistry::Global ()->Lookup (op_type, &dummy);
128+ return status.ok ();
135129}
136130
137131bool TFJ_RegisterCustomGradient (const char * op_type, TFJ_GradFuncAdapter grad_func_adapter) {
138- LOG (INFO) << " TFJ_RegisterCustomGradient(" << op_type << " ) adapter_ptr="
139- << reinterpret_cast <void *>(grad_func_adapter);
140-
141- if (grad_func_adapter == nullptr ) {
142- LOG (ERROR) << " Refusing to register NULL Java gradient adapter for op " << op_type;
143- return false ;
144- }
145-
146- if (TFJ_HasGradient (op_type)) {
147- LOG (WARNING) << " Tried to register Java gradient function for operation " << op_type
148- << " , which has already a registered function" ;
149- return false ;
150- }
151-
152- bool registered = GradOpRegistry::Global ()->Register (op_type, CustomGradFunc);
153- if (registered) {
154- g_grad_func_adapters.insert ({op_type, grad_func_adapter});
155- }
156- return registered;
132+ if (grad_func_adapter == NULL ) {
133+ LOG (ERROR) << " Refusing to register NULL Java gradient adapter for operation " << op_type;
134+ return false ;
135+ }
136+
137+ if (TFJ_HasGradient (op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash
138+ LOG (WARNING) << " Tried to register Java gradient function for operation " << op_type
139+ << " , which has already a registered function" ;
140+ return false ;
141+ }
142+ bool registered = GradOpRegistry::Global ()->Register (op_type, CustomGradFunc);
143+ if (registered) {
144+ g_grad_func_adapters.insert ({op_type, grad_func_adapter});
145+ }
146+ return registered;
157147}
158148
159- #else // _WIN32
149+ #else // #ifndef _WIN32
150+
151+ /* This extension is not available on Windows */
160152
161153bool TFJ_HasGradient (const char * op_type) { return true ; }
162- bool TFJ_RegisterCustomGradient (const char * op_type, TFJ_GradFuncAdapter grad_func_adapter) {
163- return false ;
164- }
154+ bool TFJ_RegisterCustomGradient (const char * op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false ; }
165155
166- #endif // _WIN32
156+ #endif // #ifndef _WIN32
0 commit comments