
Commit 8d80312

apply mvn spotless
1 parent 69a9a36 commit 8d80312

3 files changed: 116 additions & 127 deletions
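(For context: Spotless is the code formatter wired into the project's Maven build, so the commit title means the diff below is the output of running its apply goal, i.e. `mvn spotless:apply`.)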


tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java

Lines changed: 5 additions & 5 deletions

@@ -80,11 +80,11 @@ private static List<Output<?>> fromNativeOutputs(Graph g, TF_Output nativeOutput
   }
 
   /**
-  * Put the Java outputs into the array of native outputs, resizing it to the necessary size.
-  *
-  * @param outputs the outputs to put
-  * @return pointer to the native array of outputs
-  */
+   * Put the Java outputs into the array of native outputs, resizing it to the necessary size.
+   *
+   * @param outputs the outputs to put
+   * @return pointer to the native array of outputs
+   */
   private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
     // Use malloc to allocate native outputs, as they will be freed by the native layer and we do
     // not want JavaCPP to deallocate them
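The comment retained in context here is the Java half of a cross-language ownership rule: the TF_Output array handed over to the C++ bridge (tfj_gradients_impl.cc, below) must come from the plain C allocator, because the bridge releases it with free(). A minimal sketch of the native-side expectation, illustrative only and not part of this commit; it relies solely on TF_Output from tensorflow/c/c_api.h, which is a plain struct holding a TF_Operation* and an index:

    #include <cstdlib>
    #include "tensorflow/c/c_api.h"  // TF_Output: struct { TF_Operation* oper; int index; }

    // The Java side must allocate with plain malloc (no JavaCPP deallocator attached),
    // because the consuming native side disposes of the array like this:
    void consume_outputs(TF_Output* outputs, int num_outputs) {
      // ... read outputs[0 .. num_outputs - 1] ...
      std::free(outputs);  // undefined behavior if another allocator created the block
    }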

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientsTest.java

Lines changed: 1 addition & 2 deletions

@@ -1,11 +1,10 @@
 package org.tensorflow;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.assertFalse;
 
-import java.util.List;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.DisabledOnOs;
 import org.junit.jupiter.api.condition.OS;
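Both changes are formatter housekeeping: assertFalse moves into alphabetical order among the static imports, and the java.util.List import (apparently unused) is dropped, as Spotless's removeUnusedImports step would do if enabled.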

tensorflow-core/tensorflow-core-native/src/main/native/org/tensorflow/internal/c_api/tfj_gradients_impl.cc

Lines changed: 110 additions & 120 deletions

@@ -18,8 +18,6 @@ limitations under the License.
 #include <stdio.h>
 #include <stdlib.h>
 
-// IMPORTANT: explicit STL includes (this .cc is included via tfj_gradients.h in some builds,
-// so we must not rely on transitive includes from other headers).
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -31,136 +29,128 @@ limitations under the License.
 #include "tensorflow/cc/framework/grad_op_registry.h"
 
 namespace tensorflow {
-namespace java {
-
-using namespace tsl;
-using namespace std;
-
-unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
-
-// Cast helper (inspired by TF C-API)
-template <typename T, typename U>
-T* struct_cast(U* ptr) {
-  return static_cast<T*>(static_cast<void*>(ptr));
-}
-
-// Bridge called by TF runtime when building gradients for op
-Status CustomGradFunc(const Scope& scope,
-                      const Operation& op,
-                      const vector<Output>& grad_inputs,
-                      vector<Output>* grad_outputs) {
-  const string& op_type = op.node()->type_string();
-  auto found_adapter = g_grad_func_adapters.find(op_type);
-  if (found_adapter == g_grad_func_adapters.end()) {
-    return errors::NotFound("No gradient adapter found for operation ", op_type);
-  }
-
-  TFJ_GradFuncAdapter adapter = found_adapter->second;
-  if (adapter == nullptr) {
-    return errors::Unknown("Null Java gradient adapter for op ", op_type);
-  }
-
-  const int num_inputs = static_cast<int>(grad_inputs.size());
-
-  TF_Output* inputs = nullptr;
-  if (num_inputs > 0) {
-    inputs = static_cast<TF_Output*>(malloc(num_inputs * sizeof(TF_Output)));
-    if (inputs == nullptr) {
-      return errors::ResourceExhausted(
-          "Out of memory allocating inputs for custom gradient of op ", op_type);
-    }
-  }
-
-  for (int i = 0; i < num_inputs; ++i) {
-    const Output& grad_input = grad_inputs[i];
-    inputs[i].oper = struct_cast<TF_Operation>(grad_input.node());
-    inputs[i].index = grad_input.index();
-  }
-
-  TF_Output* outputs = nullptr;
-
-  LOG(INFO) << "Calling Java gradient function for operation of type " << op_type;
-  const int num_outputs = adapter(
-      static_cast<TFJ_GraphId>(scope.graph()),
-      struct_cast<TFJ_Scope>(const_cast<Scope*>(&scope)),
-      struct_cast<TF_Operation>(op.node()),
-      inputs,
-      num_inputs,
-      &outputs);
-
-  if (inputs != nullptr) free(inputs);
-
-  // Adapter contract:
-  // - num_outputs < 0 indicates failure
-  // - num_outputs == 0: OK, outputs may be nullptr
-  // - num_outputs > 0: outputs must be non-null
-  if (num_outputs < 0) {
-    if (outputs != nullptr) free(outputs);
-    return errors::Unknown("Java custom gradient adapter failed for op ", op_type,
-                           " (num_outputs=", num_outputs, ")");
-  }
-  if (num_outputs > 0 && outputs == nullptr) {
-    return errors::Unknown("Java custom gradient adapter returned null outputs for op ",
-                           op_type, " with num_outputs=", num_outputs);
-  }
-
-  grad_outputs->reserve(grad_outputs->size() + static_cast<size_t>(num_outputs));
-
-  for (int i = 0; i < num_outputs; ++i) {
-    const TF_Output out = outputs[i];
-
-    // Convention: out.oper == nullptr => NoGradient
-    if (out.oper == nullptr) {
-      grad_outputs->push_back(Output()); // TF interprets empty Output as "no grad"
-      continue;
+namespace java {
+using namespace tsl;
+using namespace std;
+
+unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;
+
+/// This method can be used to cast a pointer to/from a C struct that contains only that pointer.
+///
+/// It has been "inspired" by the TensorFlow C API code, as found at this location at the time of writing:
+/// https://github.com/tensorflow/tensorflow/blob/9d637f69f699c0c422716b56153a8b27b681891a/tensorflow/c/c_api.cc#L658
+template <typename T, typename U> T* struct_cast(U* ptr) {
+  return static_cast<T*>(static_cast<void*>(ptr));
+}
+
+/// This function is called by the TensorFlow runtime when it is time to add the gradient operations of `op` to the
+/// graph using the given `scope`.
+/// We use it as a bridge between the C++ signature in TensorFlow (tensorflow::ops::GradFunc) and our custom
+/// "C" version (TFJ_GradFuncAdapter).
+Status CustomGradFunc(const Scope& scope,
+                      const Operation& op,
+                      const vector<Output>& grad_inputs,
+                      vector<Output>* grad_outputs)
+{
+  const string& op_type = op.node()->type_string();
+  auto found_adapter = g_grad_func_adapters.find(op_type);
+  if (found_adapter == g_grad_func_adapters.end()) {
+    return errors::NotFound("No gradient adapter found for operation ", op_type);
+  }
+
+  TFJ_GradFuncAdapter adapter = found_adapter->second;
+  if (adapter == NULL) {
+    return errors::Unknown("Null Java gradient adapter for operation ", op_type);
+  }
+
+  int num_inputs = grad_inputs.size();
+  TF_Output* inputs = NULL;
+  if (num_inputs > 0) {
+    inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output));
+    if (inputs == NULL) {
+      return errors::ResourceExhausted(
+          "Out of memory allocating inputs for custom gradient of op ", op_type);
+    }
+  }
+
+  for (int i = 0; i < num_inputs; ++i) {
+    Output grad_input = grad_inputs[i];
+    inputs[i].oper = struct_cast<TF_Operation>(grad_input.node());
+    inputs[i].index = grad_input.index();
+  }
+
+  TF_Output* outputs = NULL;
+  LOG(INFO) << "Calling Java gradient function for operation of type " << op_type;
+  int num_outputs = adapter(
+      static_cast<TFJ_GraphId>(scope.graph()),
+      struct_cast<TFJ_Scope>(const_cast<Scope*>(&scope)),
+      struct_cast<TF_Operation>(op.node()),
+      inputs,
+      num_inputs,
+      &outputs
+  );
+
+  if (inputs != NULL) free(inputs);
+
+  if (num_outputs < 0) {
+    if (outputs != NULL) free(outputs);
+    return errors::Unknown("Java custom gradient adapter failed for operation ", op_type,
+                           " (num_outputs=", num_outputs, ")");
+  }
+  if (num_outputs > 0 && outputs == NULL) {
+    return errors::Unknown("Java custom gradient adapter returned null outputs for operation ",
+                           op_type, " with num_outputs=", num_outputs);
+  }
+
+  for (int i = 0; i < num_outputs; ++i) {
+    TF_Output output = outputs[i];
+
+    // Convention: output.oper == NULL => NoGradient
+    if (output.oper == NULL) {
+      grad_outputs->push_back(Output());
+    } else {
+      grad_outputs->push_back(Output(struct_cast<Node>(output.oper), output.index));
+    }
+  }
+
+  if (outputs != NULL) free(outputs);  // outputs are allocated from Java but must be freed here
+  return OkStatus();
+}
     }
-
-    grad_outputs->push_back(Output(struct_cast<Node>(out.oper), out.index));
-  }
-
-  if (outputs != nullptr) free(outputs); // allocated from Java via malloc
-  return OkStatus();
 }
 
-} // namespace java
-} // namespace tensorflow
-
 using namespace tensorflow::ops;
 using namespace tensorflow::java;
 
 bool TFJ_HasGradient(const char* op_type) {
-    GradFunc dummy;
-    tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy);
-    return status.ok();
+  GradFunc dummy;
+  tsl::Status status = GradOpRegistry::Global()->Lookup(op_type, &dummy);
+  return status.ok();
 }
 
 bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) {
-  LOG(INFO) << "TFJ_RegisterCustomGradient(" << op_type << ") adapter_ptr="
-            << reinterpret_cast<void*>(grad_func_adapter);
-
-  if (grad_func_adapter == nullptr) {
-    LOG(ERROR) << "Refusing to register NULL Java gradient adapter for op " << op_type;
-    return false;
-  }
-
-  if (TFJ_HasGradient(op_type)) {
-    LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type
-                 << ", which has already a registered function";
-    return false;
-  }
-
-  bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
-  if (registered) {
-    g_grad_func_adapters.insert({op_type, grad_func_adapter});
-  }
-  return registered;
+  if (grad_func_adapter == NULL) {
+    LOG(ERROR) << "Refusing to register NULL Java gradient adapter for operation " << op_type;
+    return false;
+  }
+
+  if (TFJ_HasGradient(op_type)) {  // Check if a gradient already exists, otherwise the JVM might abort/crash
+    LOG(WARNING) << "Tried to register a Java gradient function for operation " << op_type
+                 << ", which already has a registered function";
+    return false;
+  }
+  bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
+  if (registered) {
+    g_grad_func_adapters.insert({op_type, grad_func_adapter});
+  }
+  return registered;
 }
 
-#else // _WIN32
+#else  // #ifndef _WIN32
+
+/* This extension is not available on Windows */
 
 bool TFJ_HasGradient(const char* op_type) { return true; }
-bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) {
-  return false;
-}
+bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) { return false; }
 
-#endif // _WIN32
+#endif  // #ifndef _WIN32
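Two remarks on the native bridge above, each with a sketch that is illustrative only and not part of the commit. First, struct_cast is sound here because the C API types it bridges are single-member wrapper structs, so the wrapper and the wrapped object share an address; paraphrasing the historical definition in tensorflow/c/c_api_internal.h:

    struct TF_Operation {
      tensorflow::Node node;  // sole member, so Node* and TF_Operation* coincide
    };

Second, the adapter contract enforced by CustomGradFunc can be read straight off the code: a negative return value reports failure, zero outputs are allowed (the array may then be null), outputs must be allocated with malloc (the bridge frees them), and an entry whose oper is null means NoGradient. A minimal adapter honoring that contract, assuming the typedefs from tfj_gradients.h and a made-up op type "MyCustomOp":

    #include <cstdlib>
    #include "tfj_gradients.h"  // assumed: declares TFJ_GraphId, TFJ_Scope, TFJ_GradFuncAdapter

    // Trivial pass-through: returns each incoming gradient unchanged.
    // Signature inferred from the call site in CustomGradFunc above.
    static int PassThroughGradAdapter(TFJ_GraphId graph, TFJ_Scope* scope, TF_Operation* op,
                                      TF_Output* inputs, int num_inputs, TF_Output** outputs) {
      if (num_inputs == 0) {
        *outputs = nullptr;  // zero outputs: OK, array may be null
        return 0;
      }
      // Allocate with plain malloc: CustomGradFunc releases this array with free().
      *outputs = static_cast<TF_Output*>(std::malloc(num_inputs * sizeof(TF_Output)));
      if (*outputs == nullptr) return -1;  // negative count signals failure
      for (int i = 0; i < num_inputs; ++i) {
        (*outputs)[i] = inputs[i];  // setting .oper = nullptr instead would mean NoGradient
      }
      return num_inputs;  // number of entries written into *outputs
    }

    // Registration is refused if TensorFlow already has a native gradient for the op.
    static const bool registered = TFJ_RegisterCustomGradient("MyCustomOp", PassThroughGradAdapter);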
