package com.google.cloud.vertexai.generativeai;

import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Executes registered Java functions on behalf of the model when the model responds with
 * {@link FunctionCall}s, and wraps their return values into a {@link Content} to send back.
 *
 * <p>Only static methods can be registered. Every executed function call consumes one unit of a
 * budget that starts at {@link #getMaxFunctionCalls()}. Once the budget reaches zero, further calls
 * are rejected until {@link #resetRemainingFunctionCalls()} or {@link #setMaxFunctionCalls(int)}
 * restores it.
 *
 * <p>Instances hold mutable state (the function registry and the remaining-call counter) and
 * perform no synchronization.
 */
public final class AutomaticFunctionCallingResponder {
    /** Budget used by the no-arg constructor. */
    private static final int DEFAULT_MAX_FUNCTION_CALLS = 1;

    private static final Logger logger = Logger.getLogger(AutomaticFunctionCallingResponder.class.getName());

    /** Upper bound on automatic calls before the budget must be restored. */
    private int maxFunctionCalls;
    /** Calls still allowed before {@link #getContentFromFunctionCalls} starts throwing. */
    private int remainingFunctionCalls;
    /** Registered functions, keyed by the name the model uses to request them. */
    private final Map<String, CallableFunction> callableFunctions = new HashMap<>();

    /**
     * A registered static method together with the names of its parameters, in declaration order.
     * The names are used to look up arguments in the model-supplied argument struct.
     */
    static class CallableFunction {
        private final Method callableFunction;
        private final ImmutableList<String> orderedParameterNames;

        /**
         * @param method the method to invoke; must be static
         * @param strArr parameter names in declaration order. If empty (or null), the names are read
         *     via reflection, which requires the code to be compiled with {@code -parameters}.
         * @throws IllegalArgumentException if {@code method} is not static, or if the number of
         *     provided names differs from the method's parameter count
         * @throws IllegalStateException if no names are provided and reflection cannot supply them
         */
        CallableFunction(Method method, String... strArr) {
            validateFunction(method);
            this.callableFunction = method;
            Parameter[] parameters = method.getParameters();
            if (strArr != null && strArr.length != 0) {
                if (strArr.length != parameters.length) {
                    throw new IllegalArgumentException("The number of provided parameter names doesn't match the number of parameters in the callable function.");
                }
                this.orderedParameterNames = ImmutableList.copyOf(strArr);
            } else {
                this.orderedParameterNames = namesFromReflection(parameters);
            }
        }

        /** Reads the declared parameter names; fails if they were not compiled into the class. */
        private static ImmutableList<String> namesFromReflection(Parameter[] parameters) {
            ImmutableList.Builder<String> builder = ImmutableList.builder();
            for (Parameter parameter : parameters) {
                if (!parameter.isNamePresent()) {
                    throw new IllegalStateException("Failed to retrieve the parameter name from reflection. Please compile your code with  \"-parameters\" flag or use `addCallableFunction(String, Method, String...)` to manually enter parameter names");
                }
                builder.add(parameter.getName());
            }
            return builder.build();
        }

        /**
         * Invokes the wrapped static method, with arguments taken from {@code struct} by parameter
         * name.
         *
         * @param struct the arguments requested by the model
         * @return the method's return value ({@code null} for void methods)
         * @throws IllegalArgumentException if an argument is missing or has an unsupported kind
         * @throws IllegalStateException if the reflective invocation fails or the method throws
         */
        Object call(Struct struct) {
            Map<String, Value> fieldsMap = struct.getFieldsMap();
            // Hoisted: getParameters() returns a fresh array copy on every call.
            Parameter[] parameters = this.callableFunction.getParameters();
            List<Object> args = new ArrayList<>(this.orderedParameterNames.size());
            for (int i = 0; i < this.orderedParameterNames.size(); i++) {
                String parameterName = this.orderedParameterNames.get(i);
                if (!fieldsMap.containsKey(parameterName)) {
                    throw new IllegalArgumentException("The parameter \"" + parameterName + "\" was not found in the arguments requested by the model. Args map: " + fieldsMap);
                }
                args.add(toJavaArgument(fieldsMap.get(parameterName), parameters[i].getType(), parameterName));
            }
            // Supplier form: the message is only built when INFO logging is enabled.
            logger.info(() -> "Automatically calling function: " + callableFunction.getName() + args.toString().replace('[', '(').replace(']', ')'));
            try {
                return this.callableFunction.invoke(null, args.toArray());
            } catch (Exception e) {
                throw new IllegalStateException("Error raised when calling function \"" + this.callableFunction.getName() + "\" as requested by the model. ", e);
            }
        }

        /** Converts one protobuf {@link Value} into a Java object for a parameter of {@code type}. */
        private static Object toJavaArgument(Value value, Class<?> type, String parameterName) {
            switch (value.getKindCase()) {
                case NUMBER_VALUE:
                    return convertNumber(value.getNumberValue(), type);
                case STRING_VALUE:
                    return value.getStringValue();
                case BOOL_VALUE:
                    return Boolean.valueOf(value.getBoolValue());
                case NULL_VALUE:
                    return null;
                default:
                    throw new IllegalArgumentException("Unsupported value type " + value.getKindCase() + " for parameter " + parameterName);
            }
        }

        /**
         * Narrows a JSON number to the parameter's numeric type. Both primitive and boxed forms are
         * handled; otherwise Method.invoke would reject a Double with "argument type mismatch".
         * Narrowing uses Java casts, so fractional parts are truncated and out-of-range values are
         * clamped, matching the original int/float behavior.
         */
        private static Object convertNumber(double number, Class<?> type) {
            if (type == int.class || type == Integer.class) {
                return Integer.valueOf((int) number);
            }
            if (type == long.class || type == Long.class) {
                return Long.valueOf((long) number);
            }
            if (type == float.class || type == Float.class) {
                return Float.valueOf((float) number);
            }
            return Double.valueOf(number);
        }

        private static void validateFunction(Method method) {
            if (!Modifier.isStatic(method.getModifiers())) {
                throw new IllegalArgumentException("Function calling only supports static methods.");
            }
        }
    }

    /** Creates a responder that allows a single automatic function call per reset. */
    public AutomaticFunctionCallingResponder() {
        this(DEFAULT_MAX_FUNCTION_CALLS);
    }

    /**
     * Creates a responder with the given automatic-call budget.
     *
     * @param i maximum number of automatic function calls before the budget must be reset
     */
    public AutomaticFunctionCallingResponder(int i) {
        this.maxFunctionCalls = i;
        this.remainingFunctionCalls = i;
    }

    /**
     * Sets the automatic-call budget and immediately restores the remaining calls to it.
     *
     * @param i new maximum number of automatic function calls
     */
    public void setMaxFunctionCalls(int i) {
        this.maxFunctionCalls = i;
        this.remainingFunctionCalls = i;
    }

    /** Returns the configured maximum number of automatic function calls. */
    public int getMaxFunctionCalls() {
        return this.maxFunctionCalls;
    }

    /**
     * Restores the remaining-call counter to {@link #getMaxFunctionCalls()}.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private; it is left
     * public so existing callers keep compiling.
     */
    public void resetRemainingFunctionCalls() {
        this.remainingFunctionCalls = this.maxFunctionCalls;
    }

    /**
     * Registers a static method so the model can call it by {@code str}.
     *
     * @param str the function name the model will use
     * @param method the static method to invoke
     * @param strArr optional parameter names in declaration order (see {@link CallableFunction})
     * @throws IllegalArgumentException if {@code str} is already registered, the method is not
     *     static, or the number of names does not match the method's parameters
     * @throws IllegalStateException if names are omitted and unavailable via reflection
     */
    public void addCallableFunction(String str, Method method, String... strArr) {
        if (this.callableFunctions.containsKey(str)) {
            throw new IllegalArgumentException("Duplicate function name: " + str);
        }
        this.callableFunctions.put(str, new CallableFunction(method, strArr));
    }

    /**
     * Executes each requested function call in order. Each result is wrapped as a function-response
     * part of the form {@code {"result": <return value>}}.
     *
     * <p>Each call consumes one unit of the remaining budget, including a call whose name turns out
     * to be unregistered.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private; it is left
     * public so existing callers keep compiling.
     *
     * @param list the function calls requested by the model; must not be null
     * @return a {@link Content} holding one function-response part per call
     * @throws NullPointerException if {@code list} is null
     * @throws IllegalStateException if the budget is exhausted or a function fails
     * @throws IllegalArgumentException if the model requests an unregistered function or supplies
     *     missing or unsupported arguments
     */
    public Content getContentFromFunctionCalls(List<FunctionCall> list) {
        Preconditions.checkNotNull(list, "functionCalls cannot be null.");
        List<Object> responseParts = new ArrayList<>(list.size());
        for (FunctionCall functionCall : list) {
            if (this.remainingFunctionCalls <= 0) {
                throw new IllegalStateException("Exceeded the maximum number of continuous automatic function calls (" + this.maxFunctionCalls + "). If more automatic function calls are needed, please call `setMaxFunctionCalls() to set a higher number. The last function call is:\n" + functionCall);
            }
            this.remainingFunctionCalls--;
            String name = functionCall.getName();
            CallableFunction callableFunction = this.callableFunctions.get(name);
            if (callableFunction == null) {
                throw new IllegalArgumentException("Model has asked to call function \"" + name + "\" which was not found.");
            }
            // Typed local selects the Map overload of fromFunctionResponse without an unchecked cast.
            Map<String, Object> response = Collections.singletonMap("result", callableFunction.call(functionCall.getArgs()));
            responseParts.add(PartMaker.fromFunctionResponse(name, response));
        }
        return ContentMaker.fromMultiModalData(responseParts.toArray());
    }
}
